From fed0a4451c0c15741bc2b9ea0d9fd515c1a9acaf Mon Sep 17 00:00:00 2001 From: Robert Haschke Date: Wed, 24 Mar 2021 00:48:35 +0100 Subject: [PATCH 1/2] OverrideTestPassRef: roundtrip test for references passed to overriden methods Add tests to ensure that reference arguments passed to Python-overriden methods (from C++ -> Python -> C++) are actually passed by reference (and not copied), which is the default return_value_policy for passing from C++ to Python. --- tests/test_virtual_functions.cpp | 91 +++++++++++++++++++++ tests/test_virtual_functions.py | 132 +++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) diff --git a/tests/test_virtual_functions.cpp b/tests/test_virtual_functions.cpp index 685d64a7ca..04ca254d70 100644 --- a/tests/test_virtual_functions.cpp +++ b/tests/test_virtual_functions.cpp @@ -459,6 +459,79 @@ template class PyD_Tpl : public PyC_Tpl { }; */ +/* Tests passing objects by reference to Python-derived class methods */ +struct Passenger { // object class passed around, recording its copy and move constructions + std::string mtxt; + Passenger() = default; + // on copy or move: keep old mtxt and augment operation as well as new pointer id + Passenger(const Passenger &other) { mtxt = other.mtxt + "Copy->" + std::to_string(id()); } + Passenger(Passenger &&other) { mtxt = other.mtxt + "Move->" + std::to_string(id()); } + uintptr_t id() const { return reinterpret_cast(this); } +}; +struct ReferencePassingTest { // virtual base class used to test reference passing + ReferencePassingTest() = default; + ReferencePassingTest(const ReferencePassingTest &) = default; + ReferencePassingTest(ReferencePassingTest &&) = default; + virtual ~ReferencePassingTest() = default; + // NOLINTNEXTLINE(clang-analyzer-core.StackAddrEscapeBase) + virtual uintptr_t pass_valu(Passenger obj) { return modify(obj); }; + virtual uintptr_t pass_mref(Passenger &obj) { return modify(obj); }; + virtual uintptr_t pass_mptr(Passenger *obj) { return modify(*obj); }; + virtual uintptr_t pass_cref(const Passenger &obj) { return modify(obj); }; + virtual uintptr_t pass_cptr(const Passenger *obj) { return modify(*obj); }; + uintptr_t modify(const Passenger &obj) { return obj.id(); } + uintptr_t modify(Passenger &obj) { + obj.mtxt.append("_MODIFIED"); + return obj.id(); + } +}; +struct PyReferencePassingTest : ReferencePassingTest { + using ReferencePassingTest::ReferencePassingTest; + uintptr_t pass_valu(Passenger obj) override { + PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_valu, obj); + } + uintptr_t pass_mref(Passenger &obj) override { + PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_mref, obj); + } + uintptr_t pass_cref(const Passenger &obj) override { + PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_cref, obj); + } + uintptr_t pass_mptr(Passenger *obj) override { + PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_mptr, obj); + } + uintptr_t pass_cptr(const Passenger *obj) override { + PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_cptr, obj); + } +}; + +std::string evaluate(const Passenger &orig, uintptr_t cycled) { + return orig.mtxt + (orig.id() == cycled ? "_REF" : "_COPY"); +} +// Functions triggering virtual-method calls from python-derived class (caller) +// Goal: modifications to Passenger happening in Python-code methods +// overriding the C++ virtual methods, should remain visible in C++. +// TODO: Find template magic to avoid this code duplication +std::string check_roundtrip_valu(ReferencePassingTest &caller) { + Passenger obj; + return evaluate(obj, caller.pass_valu(obj)); +} +std::string check_roundtrip_mref(ReferencePassingTest &caller) { + Passenger obj; + return evaluate(obj, caller.pass_mref(obj)); +} +std::string check_roundtrip_cref(ReferencePassingTest &caller) { + Passenger obj; + return evaluate(obj, caller.pass_cref(obj)); +} +std::string check_roundtrip_mptr(ReferencePassingTest &caller) { + Passenger obj; + return evaluate(obj, caller.pass_mptr(&obj)); +} +std::string check_roundtrip_cptr(ReferencePassingTest &caller) { + Passenger obj; + return evaluate(obj, caller.pass_cptr(&obj)); +} + void initialize_inherited_virtuals(py::module_ &m) { // test_inherited_virtuals @@ -495,4 +568,22 @@ void initialize_inherited_virtuals(py::module_ &m) { // Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7) m.def("test_gil", &test_gil); m.def("test_gil_from_thread", &test_gil_from_thread); + + py::class_(m, "Passenger") + .def_property_readonly("id", &Passenger::id) + .def_readwrite("mtxt", &Passenger::mtxt); + + py::class_(m, "ReferencePassingTest") + .def(py::init<>()) + .def("pass_valu", &ReferencePassingTest::pass_valu) + .def("pass_mref", &ReferencePassingTest::pass_mref) + .def("pass_cref", &ReferencePassingTest::pass_cref) + .def("pass_mptr", &ReferencePassingTest::pass_mptr) + .def("pass_cptr", &ReferencePassingTest::pass_cptr); + + m.def("check_roundtrip_valu", check_roundtrip_valu); + m.def("check_roundtrip_mref", check_roundtrip_mref); + m.def("check_roundtrip_cref", check_roundtrip_cref); + m.def("check_roundtrip_mptr", check_roundtrip_mptr); + m.def("check_roundtrip_cptr", check_roundtrip_cptr); }; diff --git a/tests/test_virtual_functions.py b/tests/test_virtual_functions.py index f7d3bd1e4b..80a9594a12 100644 --- a/tests/test_virtual_functions.py +++ b/tests/test_virtual_functions.py @@ -406,3 +406,135 @@ def test_issue_1454(): # Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7) m.test_gil() m.test_gil_from_thread() + + +# Python class inheriting from C++ class ReferencePassingTest +# virtual methods modify the obj's mtxt, which should become visible in C++ +# To ensure that the original object instance was passed through, +# the pointer id of the received obj is returned by all pass_*() functions +# (and compared by the C++ caller with the originally passed obj id). +class PyReferencePassingTest1(m.ReferencePassingTest): + def __init__(self): + m.ReferencePassingTest.__init__(self) + + def pass_valu(self, obj): + obj.mtxt = obj.mtxt + "pass_valu" + return obj.id + + def pass_mref(self, obj): + obj.mtxt = obj.mtxt + "pass_mref" + return obj.id + + def pass_mptr(self, obj): + obj.mtxt = obj.mtxt + "pass_mptr" + return obj.id + + def pass_cref(self, obj): + with pytest.raises(Exception): # should be forbidden + obj.mtxt = obj.mtxt + "pass_cref" + return obj.id + + def pass_cptr(self, obj): + with pytest.raises(Exception): # should be forbidden + obj.mtxt = obj.mtxt + "pass_cptr" + return obj.id + + +# This class, in contrast to PyReferencePassingTest1, calls the base class methods as well, +# which will augment mtxt with a _MODIFIED stamp. +# These calls to the base class methods actually result in a 2nd call to the +# trampoline override dispatcher, requiring argument loading, which should pass +# references through as well, to make these tests succeed. +# argument is passed like this: C++ -> Python (call #1) -> C++ (call #2). +class PyReferencePassingTest2(m.ReferencePassingTest): + def __init__(self): + m.ReferencePassingTest.__init__(self) + + def pass_valu(self, obj): + obj.mtxt = obj.mtxt + "pass_valu" + return m.ReferencePassingTest.pass_valu(self, obj) + + def pass_mref(self, obj): + obj.mtxt = obj.mtxt + "pass_mref" + return m.ReferencePassingTest.pass_mref(self, obj) + + def pass_mptr(self, obj): + obj.mtxt = obj.mtxt + "pass_mptr" + return m.ReferencePassingTest.pass_mptr(self, obj) + + def pass_cref(self, obj): + with pytest.raises(Exception): # should be forbidden + obj.mtxt = obj.mtxt + "pass_cref" + return m.ReferencePassingTest.pass_cref(self, obj) + + def pass_cptr(self, obj): + with pytest.raises(Exception): # should be forbidden + obj.mtxt = obj.mtxt + "pass_cptr" + return m.ReferencePassingTest.pass_cptr(self, obj) + + +# roundtrip tests, creating a Passenger object in C++ that is passed by reference +# to a virtual method of a class derived in Python (PyReferencePassingTest1). +# If the object is correctly passed by reference, modifications should be visible +# by the C++ caller. The final obj's mtxt is returned by the check_* functions +# and validated here. Expected scheme: _[REF|COPY] +@pytest.mark.parametrize( + "f, expected", + [ + (m.check_roundtrip_valu, "_COPY"), # modification not passed back to C++ + (m.check_roundtrip_mref, "pass_mref_REF"), + (m.check_roundtrip_mptr, "pass_mptr_REF"), + ], +) +def test_refpassing1_roundtrip_modifyable(f, expected): + c = PyReferencePassingTest1() + assert f(c) == expected + + +@pytest.mark.parametrize( + "f, expected", + [ + # object passed as reference, but not modified + (m.check_roundtrip_cref, "_REF"), + (m.check_roundtrip_cptr, "_REF"), + ], +) +# PYPY always copies the argument (to ensure constness?) +@pytest.mark.skipif("env.PYPY") +@pytest.mark.xfail # maintaining constness isn't implemented yet +def test_refpassing1_roundtrip_const(f, expected): + c = PyReferencePassingTest1() + assert f(c) == expected + + +# Similar test as above, but now using PyReferencePassingTest2, calling +# to the C++ base class methods as well. +# Expected mtxt scheme: _MODIFIED_[REF|COPY] +@pytest.mark.parametrize( + "f, expected", + [ + # object copied, modification not passed back to C++ + (m.check_roundtrip_valu, "_COPY"), + (m.check_roundtrip_mref, "pass_mref_MODIFIED_REF"), + (m.check_roundtrip_mptr, "pass_mptr_MODIFIED_REF"), + ], +) +def test_refpassing2_roundtrip_modifyable(f, expected): + c = PyReferencePassingTest2() + assert f(c) == expected + + +@pytest.mark.parametrize( + "f, expected", + [ + # object passed as reference, but not modified + (m.check_roundtrip_cref, "_REF"), + (m.check_roundtrip_cptr, "_REF"), + ], +) +# PYPY always copies the argument (to ensure constness?) +@pytest.mark.skipif("env.PYPY") +@pytest.mark.xfail # maintaining constness isn't implemented yet +def test_refpassing2_roundtrip_const(f, expected): + c = PyReferencePassingTest2() + assert f(c) == expected From 2c95a3979115f208f5b0070170510bdddbd3ee69 Mon Sep 17 00:00:00 2001 From: Robert Haschke Date: Tue, 23 Mar 2021 01:18:16 +0100 Subject: [PATCH 2/2] Pass arguments from trampoline methods via return_value_policy::reference Otherwise, modifications applied by Python-coded method overrides would be applied to copies, even though the parameters were passed by pointer or reference. --- include/pybind11/pybind11.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 648300eef2..72401a37be 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -2179,7 +2179,7 @@ template function get_override(const T *this_ptr, const char *name) { pybind11::gil_scoped_acquire gil; \ pybind11::function override = pybind11::get_override(static_cast(this), name); \ if (override) { \ - auto o = override(__VA_ARGS__); \ + auto o = override.operator()(__VA_ARGS__); \ if (pybind11::detail::cast_is_temporary_value_reference::value) { \ static pybind11::detail::override_caster_t caster; \ return pybind11::detail::cast_ref(std::move(o), caster); \