Skip to content

Passing reference arguments to trampoline methods [smart_holder] #2916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: archive/smart_holder
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions include/pybind11/detail/smart_holder_type_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,12 +685,10 @@ struct smart_holder_type_caster<std::unique_ptr<T, D>> : smart_holder_type_caste
smart_holder_type_caster_class_hooks {
static constexpr auto name = _<std::unique_ptr<T, D>>();

static handle cast(std::unique_ptr<T, D> &&src, return_value_policy policy, handle parent) {
if (policy != return_value_policy::automatic
&& policy != return_value_policy::reference_internal
&& policy != return_value_policy::move) {
static handle cast(std::unique_ptr<T, D> &&src, return_value_policy policy, handle) {
if (policy != return_value_policy::automatic && policy != return_value_policy::move) {
// SMART_HOLDER_WIP: IMPROVABLE: Error message.
throw cast_error("Invalid return_value_policy for unique_ptr.");
throw cast_error("Invalid return_value_policy: unique_ptr&& can only move");
}

auto src_raw_ptr = src.get();
Expand All @@ -712,26 +710,37 @@ struct smart_holder_type_caster<std::unique_ptr<T, D>> : smart_holder_type_caste
auto smhldr = pybindit::memory::smart_holder::from_unique_ptr(std::move(src));
tinfo->init_instance(inst_raw_ptr, static_cast<const void *>(&smhldr));

if (policy == return_value_policy::reference_internal)
keep_alive_impl(inst, parent);

return inst.release();
}
static handle cast(std::unique_ptr<T, D> &, return_value_policy, handle) {
throw cast_error("Passing non-const unique_ptr& is not supported. "
"If you want to transfer ownership, use unique_ptr&&. "
"If you want to return a reference, use unique_ptr const&.");
}

static handle
cast(const std::unique_ptr<T, D> &src, return_value_policy policy, handle parent) {
if (!src)
return none().release();
if (policy == return_value_policy::automatic)
policy = return_value_policy::reference_internal;
if (policy != return_value_policy::reference_internal)
throw cast_error("Invalid return_value_policy for unique_ptr&");
else if (policy == return_value_policy::reference && !parent)
; // passing from trampoline dispatcher: no parent available
else if (policy != return_value_policy::reference_internal)
throw cast_error(
"Invalid return_value_policy: unique_ptr const& expects reference_internal");
return smart_holder_type_caster<T>::cast(src.get(), policy, parent);
}

template <typename>
using cast_op_type = std::unique_ptr<T, D>;

operator std::unique_ptr<T, D>() { return this->template loaded_as_unique_ptr<D>(); }
// TODO: To allow passing unique_ptr const-references from Python to C++,
// we need to add another cast operator for const std::unique_ptr<T, D>&().
// The above cast operator always moves the held raw pointer, even if the argument only
// asked for a (const!) reference.
// See test_class_sh_basic.py::test_unique_ptr_cref_store_roundtrip
};

template <typename T, typename D>
Expand Down
2 changes: 1 addition & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2243,7 +2243,7 @@ template <class T> function get_override(const T *this_ptr, const char *name) {
pybind11::gil_scoped_acquire gil; \
pybind11::function override = pybind11::get_override(static_cast<const cname *>(this), name); \
if (override) { \
auto o = override(__VA_ARGS__); \
auto o = override.operator()<pybind11::return_value_policy::reference>(__VA_ARGS__); \
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
static pybind11::detail::override_caster_t<ret_type> caster; \
return pybind11::detail::cast_ref<ret_type>(std::move(o), caster); \
Expand Down
56 changes: 35 additions & 21 deletions tests/test_class_sh_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <pybind11/smart_holder.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>
Expand All @@ -17,18 +18,27 @@ struct atyp { // Short for "any type".
atyp(atyp &&other) { mtxt = other.mtxt + "_MvCtor"; }
};

struct uconsumer { // unique_ptr consumer
// clang-format off
struct store { // unique_ptr store:
std::unique_ptr<atyp> held;
bool valid() const { return static_cast<bool>(held); }

void pass_valu(std::unique_ptr<atyp> obj) { held = std::move(obj); }
void pass_rref(std::unique_ptr<atyp> &&obj) { held = std::move(obj); }
std::unique_ptr<atyp> rtrn_valu() { return std::move(held); }
std::unique_ptr<atyp> &rtrn_lref() { return held; }
const std::unique_ptr<atyp> &rtrn_cref() { return held; }
std::string pass_uq_valu(std::unique_ptr<atyp> obj) { held = std::move(obj); return held->mtxt; }
std::string pass_uq_rref(std::unique_ptr<atyp> &&obj) { held = std::move(obj); return held->mtxt; }
std::string pass_uq_cref(const std::unique_ptr<atyp> &obj) { return obj->mtxt; }
std::string pass_cptr(const atyp *obj) { return obj->mtxt; }
std::string pass_cref(const atyp &obj) { return obj.mtxt; }

std::unique_ptr<atyp> rtrn_uq_valu() { return std::move(held); }
std::unique_ptr<atyp>&& rtrn_uq_rref() { return std::move(held); }
std::unique_ptr<atyp>& rtrn_uq_mref() { return held; }
const std::unique_ptr<atyp>& rtrn_uq_cref() { return held; }
const atyp* rtrn_cptr() const { return held.get(); }
const atyp& rtrn_cref() const { return *held; }
atyp *rtrn_mptr() { return held.get(); }
atyp &rtrn_mref() { return *held; }
};

// clang-format off

atyp rtrn_valu() { atyp obj{"rtrn_valu"}; return obj; }
atyp&& rtrn_rref() { static atyp obj; obj.mtxt = "rtrn_rref"; return std::move(obj); }
Expand Down Expand Up @@ -68,12 +78,9 @@ std::string pass_udcp(std::unique_ptr<atyp const, sddc> obj) { return "pass_udcp

// Helpers for testing.
std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }
std::uintptr_t get_ptr(atyp const &obj) { return reinterpret_cast<std::uintptr_t>(&obj); }

std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }
const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {
return obj;
}

struct SharedPtrStash {
std::vector<std::shared_ptr<const atyp>> stash;
Expand All @@ -84,7 +91,7 @@ struct SharedPtrStash {
} // namespace pybind11_tests

PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::atyp)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::uconsumer)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::store)
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_basic::SharedPtrStash)

namespace pybind11_tests {
Expand All @@ -102,7 +109,7 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("rtrn_valu", rtrn_valu);
m.def("rtrn_rref", rtrn_rref);
m.def("rtrn_cref", rtrn_cref);
m.def("rtrn_mref", rtrn_mref);
m.def("rtrn_mref", rtrn_mref, py::return_value_policy::reference);
m.def("rtrn_cptr", rtrn_cptr);
m.def("rtrn_mptr", rtrn_mptr);

Expand Down Expand Up @@ -130,22 +137,29 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("pass_udmp", pass_udmp);
m.def("pass_udcp", pass_udcp);

py::classh<uconsumer>(m, "uconsumer")
py::classh<store>(m, "store")
.def(py::init<>())
.def("valid", &uconsumer::valid)
.def("pass_valu", &uconsumer::pass_valu)
.def("pass_rref", &uconsumer::pass_rref)
.def("rtrn_valu", &uconsumer::rtrn_valu)
.def("rtrn_lref", &uconsumer::rtrn_lref)
.def("rtrn_cref", &uconsumer::rtrn_cref);
.def("valid", &store::valid)
.def("pass_uq_valu", &store::pass_uq_valu)
.def("pass_uq_rref", &store::pass_uq_rref)
.def("pass_uq_cref", &store::pass_uq_cref)
.def("pass_cptr", &store::pass_cptr)
.def("pass_cref", &store::pass_cref)
.def("rtrn_uq_valu", &store::rtrn_uq_valu)
.def("rtrn_uq_rref", &store::rtrn_uq_rref)
.def("rtrn_uq_mref", &store::rtrn_uq_mref)
.def("rtrn_uq_cref", &store::rtrn_uq_cref)
.def("rtrn_mptr", &store::rtrn_mptr, py::return_value_policy::reference_internal)
.def("rtrn_mref", &store::rtrn_mref, py::return_value_policy::reference_internal)
.def("rtrn_cptr", &store::rtrn_cptr, py::return_value_policy::reference_internal)
.def("rtrn_cref", &store::rtrn_cref, py::return_value_policy::reference_internal);

// Helpers for testing.
// These require selected functions above to work first, as indicated:
m.def("get_mtxt", get_mtxt); // pass_cref
m.def("get_ptr", get_ptr); // pass_cref

m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp
m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip);

py::classh<SharedPtrStash>(m, "SharedPtrStash")
.def(py::init<>())
Expand Down
158 changes: 110 additions & 48 deletions tests/test_class_sh_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ def test_atyp_constructors():
assert obj.__class__.__name__ == "atyp"


def check_regex(expected, actual):
result = re.match(expected + "$", actual)
if result is None:
pytest.fail("expected: '{}' != actual: '{}'".format(expected, actual))


@pytest.mark.parametrize(
"rtrn_f, expected",
[
(m.rtrn_valu, "rtrn_valu(_MvCtor)*_MvCtor"),
(m.rtrn_rref, "rtrn_rref(_MvCtor)*_MvCtor"),
(m.rtrn_cref, "rtrn_cref(_MvCtor)*_CpCtor"),
(m.rtrn_mref, "rtrn_mref(_MvCtor)*_CpCtor"),
(m.rtrn_valu, "rtrn_valu(_MvCtor){1,3}"),
(m.rtrn_rref, "rtrn_rref(_MvCtor){1}"),
(m.rtrn_cref, "rtrn_cref_CpCtor"),
(m.rtrn_mref, "rtrn_mref"),
(m.rtrn_cptr, "rtrn_cptr"),
(m.rtrn_mptr, "rtrn_mptr"),
(m.rtrn_shmp, "rtrn_shmp"),
Expand All @@ -34,25 +40,25 @@ def test_atyp_constructors():
],
)
def test_cast(rtrn_f, expected):
assert re.match(expected, m.get_mtxt(rtrn_f()))
check_regex(expected, m.get_mtxt(rtrn_f()))


@pytest.mark.parametrize(
"pass_f, mtxt, expected",
[
(m.pass_valu, "Valu", "pass_valu:Valu(_MvCtor)*_CpCtor"),
(m.pass_cref, "Cref", "pass_cref:Cref(_MvCtor)*_MvCtor"),
(m.pass_mref, "Mref", "pass_mref:Mref(_MvCtor)*_MvCtor"),
(m.pass_cptr, "Cptr", "pass_cptr:Cptr(_MvCtor)*_MvCtor"),
(m.pass_mptr, "Mptr", "pass_mptr:Mptr(_MvCtor)*_MvCtor"),
(m.pass_shmp, "Shmp", "pass_shmp:Shmp(_MvCtor)*_MvCtor"),
(m.pass_shcp, "Shcp", "pass_shcp:Shcp(_MvCtor)*_MvCtor"),
(m.pass_uqmp, "Uqmp", "pass_uqmp:Uqmp(_MvCtor)*_MvCtor"),
(m.pass_uqcp, "Uqcp", "pass_uqcp:Uqcp(_MvCtor)*_MvCtor"),
(m.pass_valu, "Valu", "pass_valu:Valu(_MvCtor){1,2}_CpCtor"),
(m.pass_cref, "Cref", "pass_cref:Cref(_MvCtor){1,2}"),
(m.pass_mref, "Mref", "pass_mref:Mref(_MvCtor){1,2}"),
(m.pass_cptr, "Cptr", "pass_cptr:Cptr(_MvCtor){1,2}"),
(m.pass_mptr, "Mptr", "pass_mptr:Mptr(_MvCtor){1,2}"),
(m.pass_shmp, "Shmp", "pass_shmp:Shmp(_MvCtor){1,2}"),
(m.pass_shcp, "Shcp", "pass_shcp:Shcp(_MvCtor){1,2}"),
(m.pass_uqmp, "Uqmp", "pass_uqmp:Uqmp(_MvCtor){1,2}"),
(m.pass_uqcp, "Uqcp", "pass_uqcp:Uqcp(_MvCtor){1,2}"),
],
)
def test_load_with_mtxt(pass_f, mtxt, expected):
assert re.match(expected, pass_f(m.atyp(mtxt)))
check_regex(expected, pass_f(m.atyp(mtxt)))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -111,53 +117,109 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
for _ in range(num_round_trips):
id_orig = id(recycled)
recycled = m.unique_ptr_roundtrip(recycled)
assert re.match("passenger(_MvCtor)*_MvCtor", m.get_mtxt(recycled))
check_regex("passenger(_MvCtor){1,2}", m.get_mtxt(recycled))
id_rtrn = id(recycled)
# Ensure the returned object is a different Python instance.
assert id_rtrn != id_orig
id_orig = id_rtrn


# This currently fails, because a unique_ptr is always loaded by value
# due to pybind11/detail/smart_holder_type_casters.h:689
# I think, we need to provide more cast operators.
@pytest.mark.skip
def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
orig = m.atyp("passenger")
id_orig = id(orig)
# Validate moving an object from Python into a C++ object store
@pytest.mark.parametrize("pass_f", [m.store.pass_uq_valu, m.store.pass_uq_rref])
def test_unique_ptr_moved(pass_f):
store = m.store()
orig = m.atyp("O")
mtxt_orig = m.get_mtxt(orig)
ptr_orig = m.get_ptr(orig)
assert re.match("O(_MvCtor){1,2}", mtxt_orig)

recycled = m.unique_ptr_cref_roundtrip(orig)
assert m.get_mtxt(orig) == mtxt_orig
assert m.get_mtxt(recycled) == mtxt_orig
assert id(recycled) == id_orig
pass_f(store, orig) # pass object to C++ store c
with pytest.raises(ValueError) as excinfo:
m.get_mtxt(orig)
assert "Python instance was disowned" in str(excinfo.value)

del orig
recycled = store.rtrn_uq_cref()
assert m.get_ptr(recycled) == ptr_orig # underlying C++ object doesn't change
assert m.get_mtxt(recycled) == mtxt_orig # object was not moved or copied


# This series of roundtrip tests checks how an object instance moved from
# Python to C++ (into store) can be later returned back to Python.
@pytest.mark.parametrize(
"pass_f, rtrn_f, moved_out, moved_in",
"rtrn_f, moved_in",
[
(m.uconsumer.pass_valu, m.uconsumer.rtrn_valu, True, True),
(m.uconsumer.pass_rref, m.uconsumer.rtrn_valu, True, True),
(m.uconsumer.pass_valu, m.uconsumer.rtrn_lref, True, False),
(m.uconsumer.pass_valu, m.uconsumer.rtrn_cref, True, False),
(m.store.rtrn_uq_valu, True), # moved back in
(m.store.rtrn_uq_rref, True), # moved back in
(m.store.rtrn_uq_mref, None), # forbidden
(m.store.rtrn_uq_cref, False), # fetched by reference
(m.store.rtrn_mref, None), # forbidden
(m.store.rtrn_cref, False), # fetched by reference
(m.store.rtrn_mptr, None), # forbidden
(m.store.rtrn_cptr, False), # fetched by reference
],
)
def test_unique_ptr_consumer_roundtrip(pass_f, rtrn_f, moved_out, moved_in):
c = m.uconsumer()
assert not c.valid()
recycled = m.atyp("passenger")
mtxt_orig = m.get_mtxt(recycled)
assert re.match("passenger_(MvCtor){1,2}", mtxt_orig)

pass_f(c, recycled)
if moved_out:
with pytest.raises(ValueError) as excinfo:
m.get_mtxt(recycled)
assert "Python instance was disowned" in str(excinfo.value)

recycled = rtrn_f(c)
assert c.valid() != moved_in
assert m.get_mtxt(recycled) == mtxt_orig
def test_unique_ptr_store_roundtrip(rtrn_f, moved_in):
c = m.store()
orig = m.atyp("passenger")
ptr_orig = m.get_ptr(orig)

c.pass_uq_valu(orig) # pass object to C++ store c
try:
recycled = rtrn_f(c) # retrieve object back from C++
except RuntimeError as excinfo: # expect failure for rtrn_uq_lref
assert (
moved_in is None
and "Passing non-const unique_ptr& is not supported" in str(excinfo)
)
return

assert m.get_ptr(recycled) == ptr_orig # do we yield the same object?
if moved_in: # store should have given up ownership?
assert c.valid() is False
else: # store still helds the object
assert c.valid() is True
del recycled
assert c.valid() is True


# Additionally to the above test_unique_ptr_store_roundtrip, this test
# validates that an object initially moved from Python to C++ can be returned
# to Python as a *const* reference/raw pointer/unique_ptr *and*, subsequently,
# passed from Python to C++ again. There shouldn't be any copy or move operation
# involved (We want the object to be passed by reference!)
@pytest.mark.parametrize(
"rtrn_f",
[m.store.rtrn_uq_cref, m.store.rtrn_cref, m.store.rtrn_cptr],
)
@pytest.mark.parametrize(
"pass_f",
[
# This fails with: ValueError: Cannot disown non-owning holder (loaded_as_unique_ptr).
# This could, at most, work for the combination rtrn_uq_cref() + pass_uq_cref(),
# i.e. fetching a unique_ptr const-ref from C++ and passing the very same reference back.
# Currently, it is forbidden - by design - to pass a unique_ptr const-ref to C++.
# unique_ptrs are always moved (if possible).
# To allow this use case, smart_holder would need to store the unique_ptr reference,
# originally received from C++, e.g. using a union of unique_ptr + shared_ptr.
pytest.param(m.store.pass_uq_cref, marks=pytest.mark.xfail),
m.store.pass_cptr,
m.store.pass_cref,
],
)
def test_unique_ptr_cref_store_roundtrip(rtrn_f, pass_f):
c = m.store()
passenger = m.atyp("passenger")
mtxt_orig = m.get_mtxt(passenger)
ptr_orig = m.get_ptr(passenger)

# moves passenger to C++ (checked in test_unique_ptr_store_roundtrip)
c.pass_uq_valu(passenger)

for _ in range(10):
cref = rtrn_f(c) # fetches const reference, should keep-alive parent c
assert pass_f(c, cref) == mtxt_orig # no copy/move happened?
assert m.get_ptr(cref) == ptr_orig # it's still the same raw pointer


def test_py_type_handle_of_atyp():
Expand Down
Loading