Commit 8eabd3e

Update on "[Executorch][Portable] Dont upcast to double for sigmoid"
Upcasting to double for compute precision may not be ATen compliant.

Reason for the internal test change: running on a Broadwell CPU vs. the test runner's Cooper Lake CPU gives different results with this change. Without this change, both Broadwell and Cooper Lake produce "Once upon a time, there was a little". With this change, Broadwell still produces "Once upon a time, there was a little", while Cooper Lake produces "Once upon a time, there was a girl". One possibility is that some XNNPACK kernel on Cooper Lake produces a slightly different numerical result that propagates through. Still landing this change, since upcasting to double for compute does not seem necessary.

Differential Revision: [D65928920](https://our.internmc.facebook.com/intern/diff/D65928920/)

[ghstack-poisoned]
1 parent 239e41c commit 8eabd3e

File tree

1 file changed: +7 additions, −13 deletions

kernels/portable/cpu/op_sigmoid.cpp

Lines changed: 7 additions & 13 deletions
```diff
@@ -36,32 +36,26 @@ Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
       out,
       "Failed to resize output tensor.");
 
-  ScalarType common_type = in.scalar_type();
-  ScalarType compute_type = utils::get_compute_type(common_type);
-  // For integer types, we need to promote to the next higher float type
-  if (compute_type != ScalarType::Float && compute_type != ScalarType::Double) {
-    compute_type = ScalarType::Float;
-  }
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type()
+                                                            : ScalarType::Float;
+  compute_type = utils::get_compute_type(compute_type);
 
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sigmoid.out";
 
-  ET_KERNEL_CHECK(
-      ctx, executorch::runtime::isRealType(compute_type), InvalidArgument, out);
-
-  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [](const CTYPE_COMPUTE val_in) {
-          CTYPE_COMPUTE in_casted = static_cast<CTYPE_COMPUTE>(val_in);
           CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
-              (static_cast<CTYPE_COMPUTE>(1.0) + exp(-in_casted));
+              (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
           return out_val;
         },
         ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         out,
-        utils::SupportedTensorDtypes::REALHBBF16);
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;
```
