From 29d6de9d2e63b567e242aea0b7949d7250f12b34 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 18 Mar 2025 17:32:16 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../cpu/pattern/unary_ufunc_realh.cpp | 19 ++++--- .../pattern/unary_ufunc_realhb_to_bool.cpp | 26 +++++----- .../unary_ufunc_realhbbf16_to_floathbf16.cpp | 27 +++++----- kernels/portable/cpu/util/dtype_util.cpp | 4 ++ kernels/portable/cpu/util/dtype_util.h | 50 +++++++++++++++++++ 5 files changed, 94 insertions(+), 32 deletions(-) diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realh.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realh.cpp index 16d847ace31..f7050e8410b 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realh.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realh.cpp @@ -7,7 +7,7 @@ */ #include -#include +#include #include namespace torch { @@ -36,12 +36,19 @@ Tensor& unary_ufunc_realh( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] { - apply_unary_map_fn( + // TODO: this is broken for dtype_selective_build: this was + // __func__, which isn't the operator name. + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "unary_ufunc_realh"; + + ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] { + utils::apply_unitensor_elementwise_fn( [fn](const CTYPE val_in) { return static_cast(fn(val_in)); }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); + ctx, + in, + utils::SupportedTensorDtypes::REALH, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp index 367137ad02c..5a7332efc07 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhb_to_bool.cpp @@ -7,7 +7,7 @@ */ #include -#include +#include #include namespace torch { @@ -30,25 +30,23 @@ Tensor& unary_ufunc_realhb_to_bool( out, "Failed to resize output tensor."); - ET_KERNEL_CHECK_MSG( - ctx, - out.scalar_type() == executorch::aten::ScalarType::Bool, - InvalidArgument, - out, - "Expected out tensor to have dtype Bool, but got %" PRId8 " instead.", - static_cast(out.scalar_type())); - ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); const auto in_type = in.scalar_type(); - ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] { - apply_unary_map_fn( + // TODO: this is broken for dtype_selective_build: this was + // __func__, which isn't the operator name. + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "unary_ufunc_realhb_to_bool"; + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] { + utils::apply_unitensor_elementwise_fn( [fn](const CTYPE_IN val_in) { return fn(val_in); }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::BOOL); }); return out; diff --git a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp index 602b5b1bfd2..3dcdbd4050c 100644 --- a/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp +++ b/kernels/portable/cpu/pattern/unary_ufunc_realhbbf16_to_floathbf16.cpp @@ -7,7 +7,7 @@ */ #include -#include +#include #include namespace torch { @@ -38,17 +38,20 @@ Tensor& unary_ufunc_realhbbf16_to_floathbf16( const auto in_type = in.scalar_type(); const auto out_type = out.scalar_type(); - ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, __func__, CTYPE_IN, [&] { - ET_SWITCH_FLOATHBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&] { - apply_unary_map_fn( - [fn](const CTYPE_IN val_in) { - CTYPE_OUT xi = static_cast(val_in); - return static_cast(fn(xi)); - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); - }); + // TODO: this is broken for dtype_selective_build: this was + // __func__, which isn't the operator name. + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = + "unary_ufunc_realhbbf16_to_floathbf16"; + + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, op_name, CTYPE_IN, [&] { + utils::apply_unitensor_elementwise_fn( + [fn](const CTYPE_IN val_in) { return fn(val_in); }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::FLOATHBF16); }); return out; diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index d240b9f83bc..81b1b203a54 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -23,10 +23,14 @@ bool check_tensor_dtype( return executorch::runtime::tensor_is_realhbbf16_type(t); case SupportedTensorDtypes::REALHBF16: return executorch::runtime::tensor_is_realhbf16_type(t); + case SupportedTensorDtypes::REALH: + return executorch::runtime::tensor_is_realh_type(t); case SupportedTensorDtypes::FLOATHBF16: return executorch::runtime::tensor_is_floating_type(t); case SupportedTensorDtypes::INTB: return executorch::runtime::tensor_is_integral_type(t, true); + case SupportedTensorDtypes::BOOL: + return executorch::runtime::tensor_is_type(t, ScalarType::Bool); case SupportedTensorDtypes::BOOL_OR_BYTE: return (executorch::runtime::tensor_is_type( t, ScalarType::Bool, ScalarType::Byte)); diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h index 59b82cdc51b..19bee220005 100644 --- a/kernels/portable/cpu/util/dtype_util.h +++ b/kernels/portable/cpu/util/dtype_util.h @@ -51,6 +51,15 @@ load_to_common_fn get_load_to_common_fn_realhbf16( return result; } +template +load_to_common_fn get_load_to_common_fn_realh(const Tensor& t) { + CTYPE_COMMON (*result)(const void*) = nullptr; + ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::load_and_convert; + }); + return result; +} + template load_to_common_fn get_load_to_common_fn_floathbf16( const Tensor& t) { @@ -72,6 +81,16 @@ load_to_common_fn get_load_to_common_fn_intb(const Tensor& t) { return result; } +template +load_to_common_fn get_load_to_common_fn_bool(const Tensor& t) { + ET_CHECK_MSG( + t.scalar_type() == ScalarType::Bool, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::load_and_convert; +} + template load_to_common_fn get_load_to_common_fn_bool_or_byte( const Tensor& t) { @@ -137,6 +156,16 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn_realhbf16( return result; } +template +store_common_to_tensor_fn get_store_common_to_tensor_fn_realh( + const Tensor& t) { + void (*result)(CTYPE_COMMON, void*) = nullptr; + ET_SWITCH_REALH_TYPES(t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::convert_and_store; + }); + return result; +} + template store_common_to_tensor_fn get_store_common_to_tensor_fn_floathbf16(const Tensor& t) { @@ -159,6 +188,17 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn_intb( return result; } +template +store_common_to_tensor_fn get_store_common_to_tensor_fn_bool( + const Tensor& t) { + ET_CHECK_MSG( + t.scalar_type() == ScalarType::Bool, + "Unhandled dtype %s for %s", + ::executorch::runtime::toString(t.scalar_type()), + op_name); + return internal::convert_and_store; +} + template store_common_to_tensor_fn get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { @@ -206,8 +246,10 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { enum class SupportedTensorDtypes { REALHBBF16, REALHBF16, + REALH, FLOATHBF16, INTB, + BOOL, BOOL_OR_BYTE, SAME_AS_COMPUTE, SAME_AS_COMMON, @@ -224,10 +266,14 @@ load_to_common_fn get_load_to_common_fn( return get_load_to_common_fn_realhbbf16(t); case SupportedTensorDtypes::REALHBF16: return get_load_to_common_fn_realhbf16(t); + case SupportedTensorDtypes::REALH: + return get_load_to_common_fn_realh(t); case SupportedTensorDtypes::FLOATHBF16: return get_load_to_common_fn_realhbf16(t); case SupportedTensorDtypes::INTB: return get_load_to_common_fn_intb(t); + case SupportedTensorDtypes::BOOL: + return get_load_to_common_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_common_fn_bool_or_byte(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: @@ -248,10 +294,14 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn( return get_store_common_to_tensor_fn_realhbbf16(t); case SupportedTensorDtypes::REALHBF16: return get_store_common_to_tensor_fn_realhbf16(t); + case SupportedTensorDtypes::REALH: + return get_store_common_to_tensor_fn_realh(t); case SupportedTensorDtypes::FLOATHBF16: return get_store_common_to_tensor_fn_floathbf16(t); case SupportedTensorDtypes::INTB: return get_store_common_to_tensor_fn_intb(t); + case SupportedTensorDtypes::BOOL: + return get_store_common_to_tensor_fn_bool(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_common_to_tensor_fn_bool_or_byte( t);