From c55ea47d062683d67a2dad531f4c4cc8bd9ba0f0 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 8 May 2024 13:48:29 -0700 Subject: [PATCH 1/2] Use compile-time promotion to reduce max/min size & build time (#3459) Summary: Yet another smaller pair of ops. Reviewed By: manuelcandales Differential Revision: D56807402 --- kernels/portable/cpu/op_maximum.cpp | 68 ++++++++++++++++++++++------ kernels/portable/cpu/op_minimum.cpp | 69 ++++++++++++++++++++++------- 2 files changed, 108 insertions(+), 29 deletions(-) diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index 3e34035d5f6..4091f2cf8ca 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -20,6 +20,50 @@ const T& max(const T& a, const T& b) { return (b > a) ? b : a; } +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = max(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner + : public ReportCanCastBug {}; + } // namespace Tensor& maximum_out( @@ -44,20 +88,16 @@ Tensor& maximum_out( ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() { ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = max(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { + MaximumInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index 767a2c4ca59..7c106a63c4f 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -20,6 +20,50 @@ const T& min(const T& a, const T& b) { return (b < a) ? b : a; } +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = min(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner + : public ReportCanCastBug {}; + } // namespace Tensor& minimum_out( @@ -44,22 +88,17 @@ Tensor& minimum_out( ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() { + using CTYPE_IN = + typename torch::executor::promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); ET_SWITCH_REAL_TYPES_AND( - Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = min(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() { + MinimumInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); From 924d2097c5bc8954f13a823da2435eb4125ce7c2 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 8 May 2024 13:48:29 -0700 Subject: [PATCH 2/2] Use compile-time promotion to reduce floor_divide size & build time (#3455) Summary: Continuing rollout of this technique. Reviewed By: manuelcandales Differential Revision: D56827786 --- kernels/portable/cpu/op_floor_divide.cpp | 93 +++++++++++++------ .../core/exec_aten/util/scalar_type_util.h | 6 ++ 2 files changed, 70 insertions(+), 29 deletions(-) diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp index 261f77ce617..0514df0ca25 100644 --- a/kernels/portable/cpu/op_floor_divide.cpp +++ b/kernels/portable/cpu/op_floor_divide.cpp @@ -20,6 +20,60 @@ namespace native { using Tensor = exec_aten::Tensor; using ScalarType = exec_aten::ScalarType; +namespace { +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner { + static void + run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) { + if (is_integral_type::value) { + if (val_b == 0) { + div_by_zero_error = true; + return static_cast(0); + } + } + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = utils::floor_divide(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&, bool&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner + : public ReportCanCastBug {}; + +} // namespace + Tensor& floor_divide_out( RuntimeContext& ctx, const Tensor& a, @@ -46,36 +100,17 @@ Tensor& floor_divide_out( Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); ET_SWITCH_REAL_TYPES( - common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [common_type, &div_by_zero_error]( - const CTYPE_A val_a, const CTYPE_B val_b) { - if (isIntegralType( - common_type, /*includeBool=*/true)) { - if (val_b == 0) { - div_by_zero_error = true; - return static_cast(0); - } - } - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = utils::floor_divide( - a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() { + FloorDivideInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out, div_by_zero_error); }); }); }); diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 595ed7a1c02..084289520aa 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -349,6 +349,12 @@ inline constexpr bool isIntegralType( t == exec_aten::ScalarType::Short); } +template +struct is_integral_type + : public std::integral_constant< + bool, + isIntegralType(CppTypeToScalarType::value, includeBool)> {}; + inline constexpr bool isFloatingType(exec_aten::ScalarType t) { return ( t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Float ||