|
8 | 8 |
|
9 | 9 | #include <cmath>
|
10 | 10 |
|
| 11 | +#include <executorch/kernels/portable/cpu/util/elementwise_util.h> |
11 | 12 | #include <executorch/kernels/portable/cpu/util/functional_util.h>
|
12 | 13 | #include <executorch/runtime/kernel/kernel_includes.h>
|
13 | 14 |
|
@@ -35,21 +36,26 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
|
35 | 36 | out,
|
36 | 37 | "Failed to resize output tensor.");
|
37 | 38 |
|
38 |
| - ScalarType in_type = in.scalar_type(); |
39 |
| - ScalarType out_type = out.scalar_type(); |
40 |
| - ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() { |
41 |
| - ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() { |
42 |
| - apply_unary_map_fn( |
43 |
| - [](const CTYPE_IN val_in) { |
44 |
| - // perform math in double to preserve precision |
45 |
| - double in_casted = static_cast<double>(val_in); |
46 |
| - double out_val = 1.0 / (1.0 + exp(-in_casted)); |
47 |
| - return static_cast<CTYPE_OUT>(out_val); |
48 |
| - }, |
49 |
| - in.const_data_ptr<CTYPE_IN>(), |
50 |
| - out.mutable_data_ptr<CTYPE_OUT>(), |
51 |
| - in.numel()); |
52 |
| - }); |
| 39 | + ScalarType compute_type = |
| 40 | + executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type() |
| 41 | + : ScalarType::Float; |
| 42 | + compute_type = utils::get_compute_type(compute_type); |
| 43 | + |
| 44 | + // @lint-ignore CLANGTIDY facebook-hte-CArray |
| 45 | + static constexpr const char op_name[] = "sigmoid.out"; |
| 46 | + |
| 47 | + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { |
| 48 | + utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( |
| 49 | + [](const CTYPE_COMPUTE val_in) { |
| 50 | + CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) / |
| 51 | + (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in)); |
| 52 | + return out_val; |
| 53 | + }, |
| 54 | + ctx, |
| 55 | + in, |
| 56 | + utils::SupportedTensorDtypes::REALHBBF16, |
| 57 | + out, |
| 58 | + utils::SupportedTensorDtypes::FLOATHBF16); |
53 | 59 | });
|
54 | 60 |
|
55 | 61 | return out;
|
|
0 commit comments