diff --git a/exir/dialects/edge/edge.yaml b/exir/dialects/edge/edge.yaml
index 039490a8397..6970f7b00cc 100644
--- a/exir/dialects/edge/edge.yaml
+++ b/exir/dialects/edge/edge.yaml
@@ -3217,6 +3217,423 @@
     other: T6
     __ret_0: T17
 
+- func: aten::elu
+  namespace: edge
+  inherits: aten::elu
+  type_alias:
+    T0: [Bool]
+    T1: [Bool, Float, Int]
+    T2: [Double]
+    T3: [Float]
+    T4: [Half]
+    T5: [Int]
+  type_constraint:
+  - self: T2
+    alpha: T0
+    scale: T0
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T0
+    scale: T1
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T0
+    scale: T1
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T0
+    scale: T1
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T0
+    scale: T3
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T0
+    scale: T5
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T0
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T0
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T0
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T3
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T3
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T3
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T5
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T5
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T1
+    scale: T5
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T0
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T1
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T1
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T1
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T3
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T3
+    scale: T5
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T0
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T1
+    input_scale: T0
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T1
+    input_scale: T3
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T1
+    input_scale: T5
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T3
+    input_scale: T1
+    __ret_0: T2
+  - self: T2
+    alpha: T5
+    scale: T5
+    input_scale: T1
+    __ret_0: T2
+  - self: T3
+    alpha: T0
+    scale: T0
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T0
+    scale: T1
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T0
+    scale: T1
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T0
+    scale: T1
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T0
+    scale: T3
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T0
+    scale: T5
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T0
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T0
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T0
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T3
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T3
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T3
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T5
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T5
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T1
+    scale: T5
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T0
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T1
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T1
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T1
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T3
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T3
+    scale: T5
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T0
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T1
+    input_scale: T0
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T1
+    input_scale: T3
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T1
+    input_scale: T5
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T3
+    input_scale: T1
+    __ret_0: T3
+  - self: T3
+    alpha: T5
+    scale: T5
+    input_scale: T1
+    __ret_0: T3
+  - self: T4
+    alpha: T0
+    scale: T0
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T0
+    scale: T1
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T0
+    scale: T1
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T0
+    scale: T1
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T0
+    scale: T3
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T0
+    scale: T5
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T0
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T0
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T0
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T3
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T3
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T3
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T5
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T5
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T1
+    scale: T5
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T0
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T1
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T1
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T1
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T3
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T3
+    scale: T5
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T0
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T1
+    input_scale: T0
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T1
+    input_scale: T3
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T1
+    input_scale: T5
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T3
+    input_scale: T1
+    __ret_0: T4
+  - self: T4
+    alpha: T5
+    scale: T5
+    input_scale: T1
+    __ret_0: T4
+
 - func: aten::embedding
   namespace: edge
   inherits: aten::embedding
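The 81 generated type_constraint rows above are mechanical: self and __ret_0 always share one floating-point dtype (Double, Float, or Half), while alpha, scale, and input_scale may each independently be Bool, Float, or Int. An illustrative Python sketch of the rule the rows appear to encode (not part of the patch; the dtype names mirror the type_alias table):

from itertools import product

TENSOR_DTYPES = ("Double", "Float", "Half")  # T2, T3, T4
SCALAR_DTYPES = ("Bool", "Float", "Int")     # the members of T1

def valid_elu_signatures():
    # self and __ret_0 share one floating-point dtype; the three scalar
    # arguments range over Bool/Float/Int independently.
    for self_dtype in TENSOR_DTYPES:
        for alpha, scale, input_scale in product(SCALAR_DTYPES, repeat=3):
            yield (self_dtype, alpha, scale, input_scale, self_dtype)

# 3 tensor dtypes x 3^3 scalar combinations = 81 distinct signatures,
# the same count as the constraint rows above.
assert len(set(valid_elu_signatures())) == 81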
diff --git a/exir/dialects/edge/op/sample_input.py b/exir/dialects/edge/op/sample_input.py
index 23d87053c9e..3986cfd2d9f 100644
--- a/exir/dialects/edge/op/sample_input.py
+++ b/exir/dialects/edge/op/sample_input.py
@@ -424,6 +424,15 @@
         ],
         "returns": [Return(ArgType.Tensor)],
     },
+    "elu.default": {  # (Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
+        "args": [
+            InArg(ArgType.Tensor),
+            InArg(ArgType.Scalar),
+            InArg(ArgType.Scalar),
+            InArg(ArgType.Scalar),
+        ],
+        "returns": [Return(ArgType.Tensor)],
+    },
     "embedding.default": {  # (Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
         "args": [
             InArg(ArgType.Tensor),
diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 7069f9140ab..a8fa6611478 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -141,6 +141,8 @@
 
 - op: div.out_mode
 
+- op: elu.out
+
 - op: embedding.out
 
 - op: empty.out
diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp
new file mode 100644
index 00000000000..d4846fb1bfb
--- /dev/null
+++ b/kernels/portable/cpu/op_elu.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cmath>
+#include <type_traits>
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch::executor::native {
+
+Tensor& elu_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Scalar& alpha,
+    const Scalar& scale,
+    const Scalar& input_scale,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+
+  static constexpr const char op_name[] = "elu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
+    using MathT = std::
+        conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
+    MathT math_alpha = 0;
+    MathT math_scale = 0;
+    MathT math_input_scale = 0;
+    ET_EXTRACT_SCALAR(alpha, math_alpha);
+    ET_EXTRACT_SCALAR(scale, math_scale);
+    ET_EXTRACT_SCALAR(input_scale, math_input_scale);
+    const auto negcoef = math_alpha * math_scale;
+    utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
+        [negcoef, math_scale, math_input_scale](auto x) {
+          return MathT(x) <= MathT(0)
+              ? std::expm1(MathT(x) * math_input_scale) * negcoef
+              : MathT(x) * math_scale;
+        },
+        ctx,
+        in,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+  });
+  return out;
+}
+
+} // namespace torch::executor::native
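For reference, the kernel above computes the standard ELU with PyTorch's extra scale factors: out = scale * x for x > 0, and out = alpha * scale * expm1(input_scale * x) otherwise, promoting reduced-precision inputs to float for the arithmetic (the MathT alias). A minimal Python model of the lambda (an illustrative sketch, not code from this patch):

import math

def elu_ref(x, alpha=1.0, scale=1.0, input_scale=1.0):
    # negcoef folds alpha into scale for the non-positive branch,
    # exactly as op_elu.cpp does.
    negcoef = alpha * scale
    if x <= 0:
        return math.expm1(x * input_scale) * negcoef
    return x * scale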
+ "args": [ + InArg(ArgType.Tensor), + InArg(ArgType.Scalar), + InArg(ArgType.Scalar), + InArg(ArgType.Scalar), + ], + "returns": [Return(ArgType.Tensor)], + }, "embedding.default": { # (Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor "args": [ InArg(ArgType.Tensor), diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 7069f9140ab..a8fa6611478 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -141,6 +141,8 @@ - op: div.out_mode +- op: elu.out + - op: embedding.out - op: empty.out diff --git a/kernels/portable/cpu/op_elu.cpp b/kernels/portable/cpu/op_elu.cpp new file mode 100644 index 00000000000..d4846fb1bfb --- /dev/null +++ b/kernels/portable/cpu/op_elu.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +namespace torch::executor::native { + +Tensor& elu_out( + KernelRuntimeContext& ctx, + const Tensor& in, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + Tensor& out) { + ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out); + + static constexpr const char op_name[] = "elu.out"; + ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() { + using MathT = std:: + conditional_t, float, CTYPE>; + MathT math_alpha = 0; + MathT math_scale = 0; + MathT math_input_scale = 0; + ET_EXTRACT_SCALAR(alpha, math_alpha); + ET_EXTRACT_SCALAR(scale, math_scale); + ET_EXTRACT_SCALAR(input_scale, math_input_scale); + const auto negcoef = math_alpha * math_scale; + utils::apply_unitensor_elementwise_fn( + [negcoef, math_scale, math_input_scale](auto x) { + return MathT(x) <= MathT(0) + ? std::expm1(MathT(x) * math_input_scale) * negcoef + : MathT(x) * math_scale; + }, + ctx, + in, + utils::SupportedTensorDtypes::FLOATHBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); + return out; +} + +} // namespace torch::executor::native diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 29dfe8b1a0c..5e45a210a70 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -329,6 +329,11 @@ - arg_meta: null kernel_name: torch::executor::eq_tensor_out +- op: elu.out + kernels: + - arg_meta: null + kernel_name: torch::executor::elu_out + - op: erf.out kernels: - arg_meta: null diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index b9f48f0c9a1..42578acbedd 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -135,6 +135,7 @@ set(all_test_sources "op_detach_copy_test.cpp" "op_diagonal_copy_test.cpp" "op_div_test.cpp" + "op_elu_test.cpp" "op_embedding_test.cpp" "op_empty_test.cpp" "op_eq_test.cpp" diff --git a/kernels/test/op_elu_test.cpp b/kernels/test/op_elu_test.cpp new file mode 100644 index 00000000000..73ee8ac31a7 --- /dev/null +++ b/kernels/test/op_elu_test.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) Meta Platforms, Inc. 
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 18ab0ac2e28..3824551a46b 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -215,6 +215,7 @@ def define_common_targets():
     _common_op_test("op_detach_copy_test", ["aten", "portable"])
     _common_op_test("op_diagonal_copy_test", ["aten", "portable"])
     _common_op_test("op_div_test", ["aten", "portable", "optimized"])
+    _common_op_test("op_elu_test", ["aten", "portable"])
     _common_op_test("op_embedding_test", ["aten", "portable"])
     _common_op_test("op_empty_test", ["aten", "portable"])
     _common_op_test("op_eq_test", ["aten", "portable"])
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index b56413b92f4..a1ffdc1eed3 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -482,6 +482,13 @@ ATEN_OPS = (
             ":scalar_utils",
         ],
     ),
+    op_target(
+        name = "op_elu",
+        deps = [
+            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
     op_target(
         name = "op_embedding",
         deps = [