From ac77b1c65eead78d70b6f8ba872f6419f3542b01 Mon Sep 17 00:00:00 2001 From: Hossein Askari Date: Mon, 27 Jan 2025 21:59:03 -0800 Subject: [PATCH] aten.leakyrelu.default in unary_ops Differential Revision: D68688186 --- .../vulkan/runtime/graph/ops/glsl/activations.h | 12 ++++++++++++ .../vulkan/runtime/graph/ops/glsl/unary_op.yaml | 2 ++ backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 13 +++++++++++++ backends/vulkan/test/op_tests/cases.py | 1 + 4 files changed, 28 insertions(+) diff --git a/backends/vulkan/runtime/graph/ops/glsl/activations.h b/backends/vulkan/runtime/graph/ops/glsl/activations.h index 94c9e1274de..2ba0ccc467d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/activations.h +++ b/backends/vulkan/runtime/graph/ops/glsl/activations.h @@ -42,3 +42,15 @@ vec4 hardsigmoid(vec4 tex) { hardsigmoid(tex.z), hardsigmoid(tex.w)); } + +float leaky_relu(float x, float negative_slope) { + return x * (float(x > 0.0) + negative_slope * float(x <= 0.0)); +} + +vec4 leaky_relu(vec4 tex, float negative_slope) { + return vec4( + leaky_relu(tex.x, negative_slope), + leaky_relu(tex.y, negative_slope), + leaky_relu(tex.z, negative_slope), + leaky_relu(tex.w, negative_slope)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 77a334a05eb..6757d2a6d45 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -42,3 +42,5 @@ unary_op: OPERATOR: hardswish(X) - NAME: hardsigmoid OPERATOR: hardsigmoid(X) + - NAME: leaky_relu + OPERATOR: leaky_relu(X, A) diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 62922e8d9e3..4bf73fad5a1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -114,6 +114,17 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { "hardshrink"); \ } +#define DEFINE_LEAKY_RELU_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_unary_op_node( \ + graph, \ + args[0], \ + get_val_or_inf(graph, args[1], /*neg slope*/ false), \ + kDummyFloat, \ + args[2], \ + "leaky_relu"); \ + } + void gelu(ComputeGraph& graph, const std::vector& args) { // args[1] is the `approximate` string // https://fburl.com/code/9omngmyo @@ -137,6 +148,7 @@ DEFINE_RELU_FN(relu); DEFINE_HARDSHRINK_FN(hardshrink); DEFINE_ACTIVATION_FN(hardswish); DEFINE_ACTIVATION_FN(hardsigmoid); +DEFINE_LEAKY_RELU_FN(leaky_relu); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -155,6 +167,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardshrink.default, hardshrink); VK_REGISTER_OP(aten.hardswish.default, hardswish); VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); + VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 9cec4891c10..2130573c0cc 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1072,6 +1072,7 @@ def get_reduce_op_inputs(): "aten.cos.default", "aten.hardswish.default", "aten.hardsigmoid.default", + "aten.leaky_relu.default", ] ) def get_unary_ops_inputs():