Skip to content

Commit 6968d01

Browse files
authored
[Executorch][Portable] Dont upcast to double for sigmoid
Differential Revision: D65928920 Pull Request resolved: #6892
1 parent dc41596 commit 6968d01

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

kernels/portable/cpu/op_sigmoid.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <cmath>
1010

11+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1112
#include <executorch/kernels/portable/cpu/util/functional_util.h>
1213
#include <executorch/runtime/kernel/kernel_includes.h>
1314

@@ -35,21 +36,26 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
3536
out,
3637
"Failed to resize output tensor.");
3738

38-
ScalarType in_type = in.scalar_type();
39-
ScalarType out_type = out.scalar_type();
40-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
41-
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
42-
apply_unary_map_fn(
43-
[](const CTYPE_IN val_in) {
44-
// perform math in double to preserve precision
45-
double in_casted = static_cast<double>(val_in);
46-
double out_val = 1.0 / (1.0 + exp(-in_casted));
47-
return static_cast<CTYPE_OUT>(out_val);
48-
},
49-
in.const_data_ptr<CTYPE_IN>(),
50-
out.mutable_data_ptr<CTYPE_OUT>(),
51-
in.numel());
52-
});
39+
ScalarType compute_type =
40+
executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type()
41+
: ScalarType::Float;
42+
compute_type = utils::get_compute_type(compute_type);
43+
44+
// @lint-ignore CLANGTIDY facebook-hte-CArray
45+
static constexpr const char op_name[] = "sigmoid.out";
46+
47+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48+
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49+
[](const CTYPE_COMPUTE val_in) {
50+
CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
51+
(static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
52+
return out_val;
53+
},
54+
ctx,
55+
in,
56+
utils::SupportedTensorDtypes::REALHBBF16,
57+
out,
58+
utils::SupportedTensorDtypes::FLOATHBF16);
5359
});
5460

5561
return out;

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,9 @@ ATEN_OPS = (
10741074
name = "op_sigmoid",
10751075
deps = [
10761076
"//executorch/kernels/portable/cpu/util:functional_util",
1077+
"//executorch/kernels/portable/cpu/util:elementwise_util",
1078+
"//executorch/kernels/portable/cpu/util:broadcast_util",
1079+
"//executorch/kernels/portable/cpu/util:dtype_util",
10771080
],
10781081
),
10791082
op_target(

0 commit comments

Comments
 (0)