Skip to content

Commit 3c2fd06

Browse files
committed
fix: Invert input/output context if a Python function is called from C++
When the Callable itself is an input (parameter) to a C++ function, its arguments are outputs (C++ passes them to the Python callback), and vice versa. Therefore, we must invert them if C++ calls a Python function, but keep them the same in the other direction.
1 parent 81817ae commit 3c2fd06

6 files changed

Lines changed: 26 additions & 18 deletions

File tree

include/pybind11/detail/descr.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,5 +222,10 @@ constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
222222
return const_name("@$") + descr + const_name("@!");
223223
}
224224

225+
template <size_t N, typename... Ts>
226+
constexpr descr<N + 4, Ts...> inv_descr(const descr<N, Ts...> &descr) {
227+
return const_name("@~") + descr + const_name("@!");
228+
}
229+
225230
PYBIND11_NAMESPACE_END(detail)
226231
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

include/pybind11/functional.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,8 @@ struct type_caster<std::function<Return(Args...)>> {
138138
PYBIND11_TYPE_CASTER(
139139
type,
140140
const_name("collections.abc.Callable[[")
141-
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
142-
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
143-
+ const_name("]"));
141+
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
142+
+ const_name("], ") + make_caster<retval_type>::name + const_name("]"));
144143
};
145144

146145
PYBIND11_NAMESPACE_END(detail)

include/pybind11/pybind11.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
185185
signature += *++pc;
186186
} else if (c == '@') {
187187
// `@^ ... @!` and `@$ ... @!` are used to force arg/return value type (see
188-
// typing::Callable/detail::arg_descr/detail::return_descr)
188+
// typing::Callable/detail::arg_descr/detail::return_descr).
189+
// `@~ ... @!` inverts the current context (see detail::inv_descr).
189190
if (*(pc + 1) == '^') {
190191
is_return_value.emplace(false);
191192
++pc;
@@ -196,6 +197,11 @@ inline std::string generate_function_signature(const char *type_caster_name_fiel
196197
++pc;
197198
continue;
198199
}
200+
if (*(pc + 1) == '~') {
201+
is_return_value.emplace(!is_return_value.top());
202+
++pc;
203+
continue;
204+
}
199205
if (*(pc + 1) == '!') {
200206
is_return_value.pop();
201207
++pc;

include/pybind11/typing.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,18 +194,16 @@ struct handle_type_name<typing::Callable<Return(Args...)>> {
194194
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
195195
static constexpr auto name
196196
= const_name("collections.abc.Callable[[")
197-
+ ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
198-
+ const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
199-
+ const_name("]");
197+
+ ::pybind11::detail::concat(::pybind11::detail::inv_descr(make_caster<Args>::name)...)
198+
+ const_name("], ") + make_caster<retval_type>::name + const_name("]");
200199
};
201200

202201
template <typename Return>
203202
struct handle_type_name<typing::Callable<Return(ellipsis)>> {
204203
// PEP 484 specifies this syntax for defining only return types of callables
205204
using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
206205
static constexpr auto name = const_name("collections.abc.Callable[..., ")
207-
+ ::pybind11::detail::return_descr(make_caster<retval_type>::name)
208-
+ const_name("]");
206+
+ make_caster<retval_type>::name + const_name("]");
209207
};
210208

211209
template <typename T>

tests/test_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_cpp_function_roundtrip():
140140
def test_function_signatures(doc):
141141
assert (
142142
doc(m.test_callback3)
143-
== "test_callback3(arg0: collections.abc.Callable[[typing.SupportsInt | typing.SupportsIndex], int]) -> str"
143+
== "test_callback3(arg0: collections.abc.Callable[[int], typing.SupportsInt | typing.SupportsIndex]) -> str"
144144
)
145145
assert (
146146
doc(m.test_callback4)

tests/test_pytypes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -976,14 +976,14 @@ def test_iterator_annotations(doc):
976976
def test_fn_annotations(doc):
977977
assert (
978978
doc(m.annotate_fn)
979-
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], int]) -> None"
979+
== "annotate_fn(arg0: collections.abc.Callable[[list[str], str], typing.SupportsInt | typing.SupportsIndex]) -> None"
980980
)
981981

982982

983983
def test_fn_return_only(doc):
984984
assert (
985985
doc(m.annotate_fn_only_return)
986-
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., int]) -> None"
986+
== "annotate_fn_only_return(arg0: collections.abc.Callable[..., typing.SupportsInt | typing.SupportsIndex]) -> None"
987987
)
988988

989989

@@ -1085,7 +1085,7 @@ def test_literal(doc):
10851085
)
10861086
assert (
10871087
doc(m.identity_literal_arrow_with_callable)
1088-
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float | int], float]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
1088+
== 'identity_literal_arrow_with_callable(arg0: collections.abc.Callable[[typing.Literal["->"], float], float | int]) -> collections.abc.Callable[[typing.Literal["->"], float | int], float]'
10891089
)
10901090
assert (
10911091
doc(m.identity_literal_all_special_chars)
@@ -1325,27 +1325,27 @@ def test_arg_return_type_hints(doc, backport_typehints):
13251325
# Callable<R(A)> identity
13261326
assert (
13271327
doc(m.identity_callable)
1328-
== "identity_callable(arg0: collections.abc.Callable[[float | int], float]) -> collections.abc.Callable[[float | int], float]"
1328+
== "identity_callable(arg0: collections.abc.Callable[[float], float | int]) -> collections.abc.Callable[[float | int], float]"
13291329
)
13301330
# Callable<R(...)> identity
13311331
assert (
13321332
doc(m.identity_callable_ellipsis)
1333-
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float]) -> collections.abc.Callable[..., float]"
1333+
== "identity_callable_ellipsis(arg0: collections.abc.Callable[..., float | int]) -> collections.abc.Callable[..., float]"
13341334
)
13351335
# Nested Callable<R(A)> identity
13361336
assert (
13371337
doc(m.identity_nested_callable)
1338-
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]) -> collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float | int], float]]"
1338+
== "identity_nested_callable(arg0: collections.abc.Callable[[collections.abc.Callable[[float | int], float]], collections.abc.Callable[[float], float | int]]) -> collections.abc.Callable[[collections.abc.Callable[[float], float | int]], collections.abc.Callable[[float | int], float]]"
13391339
)
13401340
# Callable<R(A)>
13411341
assert (
13421342
doc(m.apply_callable)
1343-
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float | int], float]) -> float"
1343+
== "apply_callable(arg0: float | int, arg1: collections.abc.Callable[[float], float | int]) -> float"
13441344
)
13451345
# Callable<R(...)>
13461346
assert (
13471347
doc(m.apply_callable_ellipsis)
1348-
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float]) -> float"
1348+
== "apply_callable_ellipsis(arg0: float | int, arg1: collections.abc.Callable[..., float | int]) -> float"
13491349
)
13501350
# Union<T1, T2>
13511351
assert (

0 commit comments

Comments
 (0)