From aab3e0388c65240c6e6cb35dc6cd151a88075c79 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 25 Mar 2025 16:10:51 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- kernels/portable/cpu/op_convolution.cpp | 2 +- kernels/portable/cpu/op_cumsum.cpp | 2 +- kernels/portable/cpu/util/dtype_util.h | 201 ++++++++++--------- kernels/portable/cpu/util/elementwise_util.h | 35 ++-- 4 files changed, 123 insertions(+), 117 deletions(-) diff --git a/kernels/portable/cpu/op_convolution.cpp b/kernels/portable/cpu/op_convolution.cpp index 44da2cc0f1f..b5eb8d1f5db 100644 --- a/kernels/portable/cpu/op_convolution.cpp +++ b/kernels/portable/cpu/op_convolution.cpp @@ -414,7 +414,7 @@ Tensor& convolution_out( ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { const auto load_bias = bias.has_value() - ? utils::internal::get_load_to_common_fn( + ? utils::internal::get_load_to_compute_fn( bias.value(), utils::SupportedTensorDtypes::REALHBF16) : nullptr; convolution_wrapper( diff --git a/kernels/portable/cpu/op_cumsum.cpp b/kernels/portable/cpu/op_cumsum.cpp index 2faa67433d4..1f4aa5c458e 100644 --- a/kernels/portable/cpu/op_cumsum.cpp +++ b/kernels/portable/cpu/op_cumsum.cpp @@ -113,7 +113,7 @@ Tensor& cumsum_out( ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] { const auto load_self = - utils::internal::get_load_to_common_fn( + utils::internal::get_load_to_compute_fn( self, utils::SupportedTensorDtypes::REALHBBF16); cumsum_tensors(self, load_self, dim, out); }); diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 2bbd5de4577..e3cac54908e 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -26,189 +26,189 @@ void convert_and_store(From f, void* dst) { *reinterpret_cast(dst) = static_cast(f); } -template -using load_to_common_fn = CTYPE_COMMON (*)(const void*); +template +using load_to_compute_fn = CTYPE_COMPUTE (*)(const void*); -template -load_to_common_fn get_load_to_common_fn_realhbbf16( +template +load_to_compute_fn get_load_to_compute_fn_realhbbf16( const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_REALHBBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } -template -load_to_common_fn get_load_to_common_fn_realhbf16( +template +load_to_compute_fn get_load_to_compute_fn_realhbf16( const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_REALHBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } -template -load_to_common_fn get_load_to_common_fn_floathbf16( +template +load_to_compute_fn get_load_to_compute_fn_floathbf16( const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_FLOATHBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } -template -load_to_common_fn get_load_to_common_fn_intb(const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; +template +load_to_compute_fn get_load_to_compute_fn_intb(const Tensor& t) { + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_INT_TYPES_AND( Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } -template -load_to_common_fn get_load_to_common_fn_bool_or_byte( +template +load_to_compute_fn get_load_to_compute_fn_bool_or_byte( const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_TWO_TYPES( Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } -template -load_to_common_fn get_load_to_common_fn_same_as_compute( +template +load_to_compute_fn get_load_to_compute_fn_same_as_compute( const Tensor& t) { - constexpr auto common_scalar_type = CppTypeToScalarType::value; + constexpr auto common_scalar_type = CppTypeToScalarType::value; ET_CHECK_MSG( t.scalar_type() == common_scalar_type, "Unhandled dtype %s for %s", ::executorch::runtime::toString(common_scalar_type), op_name); - return internal::load_and_convert; + return internal::load_and_convert; } template < - typename CTYPE_COMMON, + typename CTYPE_COMPUTE, const char* op_name, - std::enable_if_t, bool> = true> -load_to_common_fn get_load_to_common_fn_same_as_common( + std::enable_if_t, bool> = true> +load_to_compute_fn get_load_to_compute_fn_same_as_common( const Tensor& t) { - CTYPE_COMMON (*result)(const void*) = nullptr; + CTYPE_COMPUTE (*result)(const void*) = nullptr; ET_SWITCH_THREE_TYPES( Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() { - result = internal::load_and_convert; + result = internal::load_and_convert; }); return result; } template < - typename CTYPE_COMMON, + typename CTYPE_COMPUTE, const char* op_name, - std::enable_if_t, bool> = true> -load_to_common_fn get_load_to_common_fn_same_as_common( + std::enable_if_t, bool> = true> +load_to_compute_fn get_load_to_compute_fn_same_as_common( const Tensor& t) { - return get_load_to_common_fn_same_as_compute(t); + return get_load_to_compute_fn_same_as_compute(t); } -template -using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*); +template +using store_compute_to_tensor_fn = void (*)(CTYPE_COMPUTE, void*); -template -store_common_to_tensor_fn -get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; +template +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_realhbbf16(const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_REALHBBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } -template -store_common_to_tensor_fn get_store_common_to_tensor_fn_realhbf16( - const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; +template +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_realhbf16(const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_REALHBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } -template -store_common_to_tensor_fn -get_store_common_to_tensor_fn_floathbf16(const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; +template +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_floathbf16(const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_FLOATHBF16_TYPES( t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } -template -store_common_to_tensor_fn get_store_common_to_tensor_fn_intb( +template +store_compute_to_tensor_fn get_store_compute_to_tensor_fn_intb( const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_INT_TYPES_AND( Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } -template -store_common_to_tensor_fn -get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; +template +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_bool_or_byte(const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_TWO_TYPES( Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } -template -store_common_to_tensor_fn -get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) { - constexpr auto common_scalar_type = CppTypeToScalarType::value; +template +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_same_as_compute(const Tensor& t) { + constexpr auto common_scalar_type = CppTypeToScalarType::value; ET_CHECK_MSG( t.scalar_type() == common_scalar_type, "Unhandled dtype %s for %s", ::executorch::runtime::toString(common_scalar_type), op_name); - return internal::convert_and_store; + return internal::convert_and_store; } template < - typename CTYPE_COMMON, + typename CTYPE_COMPUTE, const char* op_name, - std::enable_if_t, bool> = true> -store_common_to_tensor_fn -get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { - void (*result)(CTYPE_COMMON, void*) = nullptr; + std::enable_if_t, bool> = true> +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) { + void (*result)(CTYPE_COMPUTE, void*) = nullptr; ET_SWITCH_THREE_TYPES( Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() { - result = internal::convert_and_store; + result = internal::convert_and_store; }); return result; } template < - typename CTYPE_COMMON, + typename CTYPE_COMPUTE, const char* op_name, - std::enable_if_t, bool> = true> -store_common_to_tensor_fn -get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { - return get_store_common_to_tensor_fn_same_as_compute( + std::enable_if_t, bool> = true> +store_compute_to_tensor_fn +get_store_compute_to_tensor_fn_same_as_common(const Tensor& t) { + return get_store_compute_to_tensor_fn_same_as_compute( t); } @@ -220,59 +220,64 @@ enum class SupportedTensorDtypes { FLOATHBF16, INTB, BOOL_OR_BYTE, + // DEPRECATED: not likely to be correct; use SAME_AS_COMMON. SAME_AS_COMPUTE, SAME_AS_COMMON, }; namespace internal { -template -load_to_common_fn get_load_to_common_fn( +template +load_to_compute_fn get_load_to_compute_fn( const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: - return get_load_to_common_fn_realhbbf16(t); + return get_load_to_compute_fn_realhbbf16(t); case SupportedTensorDtypes::REALHBF16: - return get_load_to_common_fn_realhbf16(t); + return get_load_to_compute_fn_realhbf16(t); case SupportedTensorDtypes::FLOATHBF16: - return get_load_to_common_fn_realhbf16(t); + return get_load_to_compute_fn_realhbf16(t); case SupportedTensorDtypes::INTB: - return get_load_to_common_fn_intb(t); + return get_load_to_compute_fn_intb(t); case SupportedTensorDtypes::BOOL_OR_BYTE: - return get_load_to_common_fn_bool_or_byte(t); + return get_load_to_compute_fn_bool_or_byte(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: - return get_load_to_common_fn_same_as_compute(t); + return get_load_to_compute_fn_same_as_compute(t); case SupportedTensorDtypes::SAME_AS_COMMON: - return get_load_to_common_fn_same_as_common(t); + return get_load_to_compute_fn_same_as_common(t); } ET_CHECK(false); return nullptr; } -template -store_common_to_tensor_fn get_store_common_to_tensor_fn( +template +store_compute_to_tensor_fn get_store_compute_to_tensor_fn( const Tensor& t, SupportedTensorDtypes dtypes) { switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: - return get_store_common_to_tensor_fn_realhbbf16(t); + return get_store_compute_to_tensor_fn_realhbbf16( + t); case SupportedTensorDtypes::REALHBF16: - return get_store_common_to_tensor_fn_realhbf16(t); + return get_store_compute_to_tensor_fn_realhbf16( + t); case SupportedTensorDtypes::FLOATHBF16: - return get_store_common_to_tensor_fn_floathbf16(t); + return get_store_compute_to_tensor_fn_floathbf16( + t); case SupportedTensorDtypes::INTB: - return get_store_common_to_tensor_fn_intb(t); + return get_store_compute_to_tensor_fn_intb(t); case SupportedTensorDtypes::BOOL_OR_BYTE: - return get_store_common_to_tensor_fn_bool_or_byte( - t); + return get_store_compute_to_tensor_fn_bool_or_byte< + CTYPE_COMPUTE, + op_name>(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: - return get_store_common_to_tensor_fn_same_as_compute< - CTYPE_COMMON, + return get_store_compute_to_tensor_fn_same_as_compute< + CTYPE_COMPUTE, op_name>(t); case SupportedTensorDtypes::SAME_AS_COMMON: { - return get_store_common_to_tensor_fn_same_as_common< - CTYPE_COMMON, + return get_store_compute_to_tensor_fn_same_as_common< + CTYPE_COMPUTE, op_name>(t); } } diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index f5932069005..de97d736fbd 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -52,7 +52,7 @@ inline int64_t scalar_to(const Scalar& s) { namespace internal { template < - typename CTYPE_COMMON, + typename CTYPE_COMPUTE, const char* op_name, typename Op, typename... Args> @@ -66,7 +66,7 @@ inline void apply_elementwise_fn( (std::is_same_v> && ...)); constexpr auto kNumInputs = sizeof...(inputs); - constexpr auto compute_type = CppTypeToScalarType::value; + constexpr auto compute_type = CppTypeToScalarType::value; const auto check_input_dtype = [](auto input, auto compute_type) { return internal::check_tensor_dtype( *input.first, input.second, compute_type); @@ -78,19 +78,19 @@ inline void apply_elementwise_fn( InvalidArgument, ); struct InputInfo { - load_to_common_fn load_to_common; + load_to_compute_fn load_to_compute; const char* data_ptr; ssize_t element_size; }; std::array inputs_info = {(InputInfo{ - internal::get_load_to_common_fn( + internal::get_load_to_compute_fn( *inputs.first, inputs.second), reinterpret_cast(inputs.first->const_data_ptr()), inputs.first->element_size(), })...}; - const auto store_common_to_out = - internal::get_store_common_to_tensor_fn( + const auto store_compute_to_out = + internal::get_store_compute_to_tensor_fn( out, out_dtypes); char* const data_out = reinterpret_cast(out.mutable_data_ptr()); const auto out_element_size = out.element_size(); @@ -106,21 +106,22 @@ inline void apply_elementwise_fn( begin_it += begin; for (; (*begin_it)[0] < end; ++begin_it) { const auto& indexes = *begin_it; - std::array loaded_inputs; + std::array loaded_inputs; for (const auto idx : c10::irange(kNumInputs)) { const auto& input_info = inputs_info[idx]; - loaded_inputs[idx] = input_info.load_to_common( + loaded_inputs[idx] = input_info.load_to_compute( &input_info .data_ptr[indexes[idx + 1] * input_info.element_size]); } auto result = std::apply(compute_fun, loaded_inputs); - store_common_to_out(result, &data_out[indexes[0] * out_element_size]); + store_compute_to_out( + result, &data_out[indexes[0] * out_element_size]); } }); } } // namespace internal -template +template inline void apply_unitensor_elementwise_fn( const Op& compute_fun, KernelRuntimeContext& ctx, @@ -128,7 +129,7 @@ inline void apply_unitensor_elementwise_fn( SupportedTensorDtypes a_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn( compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); } @@ -137,7 +138,7 @@ inline void apply_unitensor_elementwise_fn( * perform a computation and write to the corresponding element of the output. * Tensor broadcasting is applied wherever it is required. */ -template +template inline void apply_bitensor_elementwise_fn( const Op& compute_fun, KernelRuntimeContext& ctx, @@ -147,7 +148,7 @@ inline void apply_bitensor_elementwise_fn( SupportedTensorDtypes b_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn( compute_fun, ctx, out, @@ -163,7 +164,7 @@ inline void apply_bitensor_elementwise_fn( * * In order to mitigate build time cost (straightforwardly |CTYPE_A| * * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun - * are passed as CTYPE_COMMON. + * are passed as CTYPE_COMPUTE. * * Each tensor's supported dtypes set must be provided. The tensor * will be checked to ensure that its dtype falls into that set. @@ -174,9 +175,9 @@ inline void apply_bitensor_elementwise_fn( * following: * * static constexpr const char op_name[] = "my_op"; - * apply_ternary_elementwise_fn. + * apply_ternary_elementwise_fn. */ -template +template inline void apply_tritensor_elementwise_fn( const Op& compute_fun, KernelRuntimeContext& ctx, @@ -188,7 +189,7 @@ inline void apply_tritensor_elementwise_fn( SupportedTensorDtypes c_dtypes, const Tensor& out, SupportedTensorDtypes out_dtypes) { - internal::apply_elementwise_fn( + internal::apply_elementwise_fn( compute_fun, ctx, out, From 6feb1d54ec12e67c5c9b5ac1f1953c94dd6ac742 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 25 Mar 2025 16:26:02 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- kernels/portable/cpu/util/elementwise_util.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index de97d736fbd..206be87f98e 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -119,7 +119,6 @@ inline void apply_elementwise_fn( } }); } -} // namespace internal template inline void apply_unitensor_elementwise_fn( @@ -206,6 +205,14 @@ inline ScalarType get_compute_type(ScalarType& common_type) { } return compute_type; } +} // namespace internal + +// DEPRECATED: these APIs should not have been stabilized for external +// use as they are undergoing active development. +using internal::apply_bitensor_elementwise_fn; +using internal::apply_tritensor_elementwise_fn; +using internal::apply_unitensor_elementwise_fn; +using internal::get_compute_type; } // namespace utils } // namespace native