Skip to content

Commit ccb7129

Browse files
swolchok and rwgk authored
Improve performance of enum_ operators by going back to specific implementation (#5887)
* Improve performance of enum_ operators by going back to specific implementation
  test_enum needs a patch because ops are now overloaded and this affects their docstrings.
* outline call_impl to save on code size
  This does cause more move constructions, as shown by the needed update to test_copy_move. Up to reviewers whether they want more code size or more moves.
* add function_ref.h to PYBIND11_HEADERS.
* Update test_copy_move tests with C++17 passing values just so we can see mostly-not-red tests
* Remove stray TODO
* fix clang-tidy
* fix clang-tidy again. add function_ref.h to test_files.py
* Add static assertion for function_ref lifetime safety in call_impl
  Add a static_assert to document and enforce that function_ref is trivially copyable, ensuring safe pass-by-value usage. This also documents the lifetime safety guarantees: function_ref is created from cap->f which lives in the capture object, and is only used synchronously within call_impl without being stored beyond its scope.
* Add #undef cleanup for enum operator macros
  Undefine all enum operator macros after their last use to prevent macro pollution and follow the existing code pattern. This matches the cleanup pattern used for the previous enum operator macros.
* Rename PYBIND11_THROW to PYBIND11_ENUM_OP_THROW_TYPE_ERROR
  Rename the macro to be more specific and avoid potential clashes with public macros. The new name clearly indicates it's scoped to enum operations and describes its purpose (throwing a type error).
* Clarify comments in function_ref.h
  Replace vague comments about 'extensions to <functional>' and 'functions' with a clearer description that this is a header-only class template similar to std::function but with non-owning semantics. This makes it clear that it's template-only and requires no additional library linking.
---------
Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
1 parent e8e8d6a commit ccb7129

File tree

7 files changed

+238
-58
lines changed

7 files changed

+238
-58
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ set(PYBIND11_HEADERS
188188
include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h
189189
include/pybind11/detail/exception_translation.h
190190
include/pybind11/detail/function_record_pyobject.h
191+
include/pybind11/detail/function_ref.h
191192
include/pybind11/detail/holder_caster_foreign_helpers.h
192193
include/pybind11/detail/init.h
193194
include/pybind11/detail/internals.h

include/pybind11/detail/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,14 @@
167167
# define PYBIND11_NOINLINE __attribute__((noinline)) inline
168168
#endif
169169

170+
#if defined(_MSC_VER)
171+
# define PYBIND11_ALWAYS_INLINE __forceinline
172+
#elif defined(__GNUC__)
173+
# define PYBIND11_ALWAYS_INLINE __attribute__((__always_inline__)) inline
174+
#else
175+
# define PYBIND11_ALWAYS_INLINE inline
176+
#endif
177+
170178
#if defined(__MINGW32__)
171179
// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared
172180
// whether it is used or not
include/pybind11/detail/function_ref.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
10+
//
11+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
12+
// See https://llvm.org/LICENSE.txt for license information.
13+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14+
//
15+
//===----------------------------------------------------------------------===//
16+
//
17+
// This file contains a header-only class template that provides functionality
18+
// similar to std::function but with non-owning semantics. It is a template-only
19+
// implementation that requires no additional library linking.
20+
//
21+
//===----------------------------------------------------------------------===//
22+
23+
/// An efficient, type-erasing, non-owning reference to a callable. This is
24+
/// intended for use as the type of a function parameter that is not used
25+
/// after the function in question returns.
26+
///
27+
/// This class does not own the callable, so it is not in general safe to store
28+
/// a FunctionRef.
29+
30+
// pybind11: modified again from executorch::runtime::FunctionRef
31+
// - renamed back to function_ref
32+
// - use pybind11 enable_if_t, remove_cvref_t, and remove_reference_t
33+
// - lint suppressions
34+
35+
// torch::executor: modified from llvm::function_ref
36+
// - renamed to FunctionRef
37+
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
38+
// - use namespaced internal::remove_cvref_t
39+
40+
#pragma once
41+
42+
#include <pybind11/detail/common.h>
43+
44+
#include <cstdint>
45+
#include <type_traits>
46+
#include <utility>
47+
48+
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
49+
PYBIND11_NAMESPACE_BEGIN(detail)
50+
51+
//===----------------------------------------------------------------------===//
52+
// Features from C++20
53+
//===----------------------------------------------------------------------===//
54+
55+
template <typename Fn>
56+
class function_ref;
57+
58+
template <typename Ret, typename... Params>
59+
class function_ref<Ret(Params...)> {
60+
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
61+
intptr_t callable;
62+
63+
template <typename Callable>
64+
// NOLINTNEXTLINE(performance-unnecessary-value-param)
65+
static Ret callback_fn(intptr_t callable, Params... params) {
66+
// NOLINTNEXTLINE(performance-no-int-to-ptr)
67+
return (*reinterpret_cast<Callable *>(callable))(std::forward<Params>(params)...);
68+
}
69+
70+
public:
71+
function_ref() = default;
72+
// NOLINTNEXTLINE(google-explicit-constructor)
73+
function_ref(std::nullptr_t) {}
74+
75+
template <typename Callable>
76+
// NOLINTNEXTLINE(google-explicit-constructor)
77+
function_ref(
78+
Callable &&callable,
79+
// This is not the copy-constructor.
80+
enable_if_t<!std::is_same<remove_cvref_t<Callable>, function_ref>::value> * = nullptr,
81+
// Functor must be callable and return a suitable type.
82+
enable_if_t<
83+
std::is_void<Ret>::value
84+
|| std::is_convertible<decltype(std::declval<Callable>()(std::declval<Params>()...)),
85+
Ret>::value> * = nullptr)
86+
: callback(callback_fn<remove_reference_t<Callable>>),
87+
callable(reinterpret_cast<intptr_t>(&callable)) {}
88+
89+
// NOLINTNEXTLINE(performance-unnecessary-value-param)
90+
Ret operator()(Params... params) const {
91+
return callback(callable, std::forward<Params>(params)...);
92+
}
93+
94+
explicit operator bool() const { return callback; }
95+
96+
bool operator==(const function_ref<Ret(Params...)> &Other) const {
97+
return callable == Other.callable;
98+
}
99+
};
100+
PYBIND11_NAMESPACE_END(detail)
101+
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

include/pybind11/pybind11.h

Lines changed: 110 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "detail/dynamic_raw_ptr_cast_if_possible.h"
1414
#include "detail/exception_translation.h"
1515
#include "detail/function_record_pyobject.h"
16+
#include "detail/function_ref.h"
1617
#include "detail/init.h"
1718
#include "detail/native_enum_data.h"
1819
#include "detail/using_smart_holder.h"
@@ -386,6 +387,46 @@ class cpp_function : public function {
386387
return unique_function_record(new detail::function_record());
387388
}
388389

390+
private:
391+
// This is outlined from the dispatch lambda in initialize to save
392+
// on code size. Crucially, we use function_ref to type-erase the
393+
// actual function lambda so that we can get code reuse for
394+
// functions with the same Return, Args, and Guard.
395+
template <typename Return, typename Guard, typename ArgsConverter, typename... Args>
396+
static handle call_impl(detail::function_call &call, detail::function_ref<Return(Args...)> f) {
397+
using namespace detail;
398+
// Static assertion: function_ref must be trivially copyable to ensure safe pass-by-value.
399+
// Lifetime safety: The function_ref is created from cap->f which lives in the capture
400+
// object stored in the function record, and is only used synchronously within this
401+
// function call. It is never stored beyond the scope of call_impl.
402+
static_assert(std::is_trivially_copyable<detail::function_ref<Return(Args...)>>::value,
403+
"function_ref must be trivially copyable for safe pass-by-value usage");
404+
using cast_out
405+
= make_caster<conditional_t<std::is_void<Return>::value, void_type, Return>>;
406+
407+
ArgsConverter args_converter;
408+
if (!args_converter.load_args(call)) {
409+
return PYBIND11_TRY_NEXT_OVERLOAD;
410+
}
411+
412+
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
413+
return_value_policy policy
414+
= return_value_policy_override<Return>::policy(call.func.policy);
415+
416+
/* Perform the function call */
417+
handle result;
418+
if (call.func.is_setter) {
419+
(void) std::move(args_converter).template call<Return, Guard>(f);
420+
result = none().release();
421+
} else {
422+
result = cast_out::cast(
423+
std::move(args_converter).template call<Return, Guard>(f), policy, call.parent);
424+
}
425+
426+
return result;
427+
}
428+
429+
protected:
389430
/// Special internal constructor for functors, lambda functions, etc.
390431
template <typename Func, typename Return, typename... Args, typename... Extra>
391432
void initialize(Func &&f, Return (*)(Args...), const Extra &...extra) {
@@ -448,13 +489,6 @@ class cpp_function : public function {
448489

449490
/* Dispatch code which converts function arguments and performs the actual function call */
450491
rec->impl = [](function_call &call) -> handle {
451-
cast_in args_converter;
452-
453-
/* Try to cast the function arguments into the C++ domain */
454-
if (!args_converter.load_args(call)) {
455-
return PYBIND11_TRY_NEXT_OVERLOAD;
456-
}
457-
458492
/* Invoke call policy pre-call hook */
459493
process_attributes<Extra...>::precall(call);
460494

@@ -463,24 +497,11 @@ class cpp_function : public function {
463497
: call.func.data[0]);
464498
auto *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));
465499

466-
/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
467-
return_value_policy policy
468-
= return_value_policy_override<Return>::policy(call.func.policy);
469-
470-
/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
471-
using Guard = extract_guard_t<Extra...>;
472-
473-
/* Perform the function call */
474-
handle result;
475-
if (call.func.is_setter) {
476-
(void) std::move(args_converter).template call<Return, Guard>(cap->f);
477-
result = none().release();
478-
} else {
479-
result = cast_out::cast(
480-
std::move(args_converter).template call<Return, Guard>(cap->f),
481-
policy,
482-
call.parent);
483-
}
500+
auto result = call_impl<Return,
501+
/* Function scope guard -- defaults to the compile-to-nothing
502+
`void_type` */
503+
extract_guard_t<Extra...>,
504+
cast_in>(call, detail::function_ref<Return(Args...)>(cap->f));
484505

485506
/* Invoke call policy post-call hook */
486507
process_attributes<Extra...>::postcall(call, result);
@@ -2245,7 +2266,7 @@ class class_ : public detail::generic_type {
22452266
static void add_base(detail::type_record &) {}
22462267

22472268
template <typename Func, typename... Extra>
2248-
class_ &def(const char *name_, Func &&f, const Extra &...extra) {
2269+
PYBIND11_ALWAYS_INLINE class_ &def(const char *name_, Func &&f, const Extra &...extra) {
22492270
cpp_function cf(method_adaptor<type>(std::forward<Func>(f)),
22502271
name(name_),
22512272
is_method(*this),
@@ -2830,38 +2851,13 @@ struct enum_base {
28302851
pos_only())
28312852

28322853
if (is_convertible) {
2833-
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
2834-
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));
2835-
28362854
if (is_arithmetic) {
2837-
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
2838-
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
2839-
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
2840-
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
2841-
PYBIND11_ENUM_OP_CONV("__and__", a & b);
2842-
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
2843-
PYBIND11_ENUM_OP_CONV("__or__", a | b);
2844-
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
2845-
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
2846-
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
28472855
m_base.attr("__invert__")
28482856
= cpp_function([](const object &arg) { return ~(int_(arg)); },
28492857
name("__invert__"),
28502858
is_method(m_base),
28512859
pos_only());
28522860
}
2853-
} else {
2854-
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
2855-
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);
2856-
2857-
if (is_arithmetic) {
2858-
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
2859-
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
2860-
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
2861-
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
2862-
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
2863-
#undef PYBIND11_THROW
2864-
}
28652861
}
28662862

28672863
#undef PYBIND11_ENUM_OP_CONV_LHS
@@ -2977,6 +2973,69 @@ class enum_ : public class_<Type> {
29772973

29782974
def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
29792975
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
2976+
#define PYBIND11_ENUM_OP_SAME_TYPE(op, expr) \
2977+
def(op, [](Type a, Type b) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2978+
#define PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE(op, expr) \
2979+
def(op, [](Type a, Type *b_ptr) { return expr; }, pybind11::name(op), arg("other"), pos_only())
2980+
#define PYBIND11_ENUM_OP_SCALAR(op, op_expr) \
2981+
def( \
2982+
op, \
2983+
[](Type a, Scalar b) { return static_cast<Scalar>(a) op_expr b; }, \
2984+
pybind11::name(op), \
2985+
arg("other"), \
2986+
pos_only())
2987+
#define PYBIND11_ENUM_OP_CONV_ARITHMETIC(op, op_expr) \
2988+
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
2989+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
2990+
PYBIND11_ENUM_OP_SCALAR(op, op_expr)
2991+
#define PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior) \
2992+
def( \
2993+
op, \
2994+
[](Type, const object &) { strict_behavior; }, \
2995+
pybind11::name(op), \
2996+
arg("other"), \
2997+
pos_only())
2998+
#define PYBIND11_ENUM_OP_STRICT_ARITHMETIC(op, op_expr, strict_behavior) \
2999+
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
3000+
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
3001+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior);
3002+
3003+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__eq__", b_ptr && a == *b_ptr);
3004+
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__ne__", !b_ptr || a != *b_ptr);
3005+
if (std::is_convertible<Type, Scalar>::value) {
3006+
PYBIND11_ENUM_OP_SCALAR("__eq__", ==);
3007+
PYBIND11_ENUM_OP_SCALAR("__ne__", !=);
3008+
if (is_arithmetic) {
3009+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__lt__", <);
3010+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__gt__", >);
3011+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__le__", <=);
3012+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ge__", >=);
3013+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__and__", &);
3014+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rand__", &);
3015+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__or__", |);
3016+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ror__", |);
3017+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__xor__", ^);
3018+
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rxor__", ^);
3019+
}
3020+
} else if (is_arithmetic) {
3021+
#define PYBIND11_ENUM_OP_THROW_TYPE_ERROR \
3022+
throw type_error("Expected an enumeration of matching type!");
3023+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__lt__", <, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3024+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__gt__", >, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3025+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__le__", <=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3026+
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__ge__", >=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
3027+
#undef PYBIND11_ENUM_OP_THROW_TYPE_ERROR
3028+
}
3029+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__eq__", return false);
3030+
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__ne__", return true);
3031+
3032+
#undef PYBIND11_ENUM_OP_SAME_TYPE
3033+
#undef PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE
3034+
#undef PYBIND11_ENUM_OP_SCALAR
3035+
#undef PYBIND11_ENUM_OP_CONV_ARITHMETIC
3036+
#undef PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE
3037+
#undef PYBIND11_ENUM_OP_STRICT_ARITHMETIC
3038+
29803039
def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
29813040
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
29823041
attr("__setstate__") = cpp_function(

tests/extra_python_package/test_files.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"include/pybind11/detail/descr.h",
8484
"include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h",
8585
"include/pybind11/detail/function_record_pyobject.h",
86+
"include/pybind11/detail/function_ref.h",
8687
"include/pybind11/detail/holder_caster_foreign_helpers.h",
8788
"include/pybind11/detail/init.h",
8889
"include/pybind11/detail/internals.h",

tests/test_copy_move.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def test_move_and_copy_loads():
7070

7171
assert c_m.copy_assignments + c_m.copy_constructions == 0
7272
assert c_m.move_assignments == 6
73-
assert c_m.move_constructions == 9
73+
assert c_m.move_constructions == 21
7474
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
7575
assert c_mc.move_assignments == 5
76-
assert c_mc.move_constructions == 8
76+
assert c_mc.move_constructions == 18
7777
assert c_c.copy_assignments == 4
78-
assert c_c.copy_constructions == 6
78+
assert c_c.copy_constructions == 14
7979
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0
8080

8181

@@ -103,12 +103,12 @@ def test_move_and_copy_load_optional():
103103

104104
assert c_m.copy_assignments + c_m.copy_constructions == 0
105105
assert c_m.move_assignments == 2
106-
assert c_m.move_constructions == 5
106+
assert c_m.move_constructions == 9
107107
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
108108
assert c_mc.move_assignments == 2
109-
assert c_mc.move_constructions == 5
109+
assert c_mc.move_constructions == 9
110110
assert c_c.copy_assignments == 2
111-
assert c_c.copy_constructions == 5
111+
assert c_c.copy_constructions == 9
112112
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0
113113

114114

0 commit comments

Comments
 (0)