@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2023 Intel Corporation
+* Copyright 2018-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -31,40 +31,29 @@ using namespace dnnl::impl::utils; |
 using namespace dnnl::impl::math;
 using namespace rnn_utils;
 
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return relu_fwd<float>(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return relu_bwd<float>(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return tanh_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return one_m_square<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return logistic_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return x_m_square<float>(s);
+float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
+        float alpha, float clipping) {
+    using namespace dnnl::impl::alg_kind;
+
+    if (prop_kind == prop_kind::forward
+            || prop_kind == prop_kind::forward_inference) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_fwd<float>(s, alpha);
+            case eltwise_tanh: return tanh_fwd<float>(s);
+            case eltwise_logistic: return logistic_fwd<float>(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else if (prop_kind == prop_kind::backward) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_bwd<float>(s, alpha);
+            case eltwise_tanh: return one_m_square<float>(s);
+            case eltwise_logistic: return x_m_square<float>(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else {
+        assert(!"unsupported propagation kind");
+    }
+    return NAN;
 }
 
 constexpr float linear(float s, float alpha, float clipping) {
@@ -118,7 +107,8 @@ rnn_postgemm_sig( |
         (rnn_postgemm_fwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return gates_t(this->activation_func(a, alpha, clipping));
+        return gates_t(activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, clipping));
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return gates_t(linear(a, alpha, clipping));
@@ -178,7 +168,8 @@ rnn_postgemm_sig( |
         (rnn_postgemm_bwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return this->activation_func(a, alpha, 0);
+        return activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, 0);
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return linear(a, alpha, 0);
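For readers skimming the change: the commit collapses six activation<alg, prop> template specializations into a single function that dispatches on runtime alg_kind_t and prop_kind_t values, and the postgemm lambdas now pull both kinds from the primitive descriptor instead of binding an activation_func member. Below is a minimal, standalone sketch of that dispatch pattern. The enums and the math helpers here are simplified stand-ins, not oneDNN's dnnl::impl definitions; only the control-flow shape mirrors the new activation() above.

#include <cassert>
#include <cmath>

// Simplified stand-ins for oneDNN's alg_kind / prop_kind enums (assumption:
// the real enums carry many more members and live in dnnl::impl).
enum class alg_kind_t { eltwise_relu, eltwise_tanh, eltwise_logistic };
enum class prop_kind_t { forward, forward_inference, backward };

// Simplified math helpers; oneDNN's relu_fwd<float>() etc. in
// dnnl::impl::math behave analogously for the float case.
static float relu_fwd(float s, float alpha) { return s > 0 ? s : alpha * s; }
static float relu_bwd(float s, float alpha) { return s > 0 ? 1.f : alpha; }
static float tanh_fwd(float s) { return std::tanh(s); }
// The backward helpers take the saved forward *output* s, not the input x.
static float one_m_square(float s) { return 1.f - s * s; } // tanh'(x) given s = tanh(x)
static float logistic_fwd(float s) { return 1.f / (1.f + std::exp(-s)); }
static float x_m_square(float s) { return s - s * s; } // sigmoid'(x) given s = sigmoid(x)

// One runtime-dispatch entry point replacing six compile-time
// specializations, mirroring the shape of the new activation() in the diff.
float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
        float alpha) {
    if (prop_kind == prop_kind_t::forward
            || prop_kind == prop_kind_t::forward_inference) {
        switch (alg_kind) {
            case alg_kind_t::eltwise_relu: return relu_fwd(s, alpha);
            case alg_kind_t::eltwise_tanh: return tanh_fwd(s);
            case alg_kind_t::eltwise_logistic: return logistic_fwd(s);
        }
    } else if (prop_kind == prop_kind_t::backward) {
        switch (alg_kind) {
            case alg_kind_t::eltwise_relu: return relu_bwd(s, alpha);
            case alg_kind_t::eltwise_tanh: return one_m_square(s);
            case alg_kind_t::eltwise_logistic: return x_m_square(s);
        }
    }
    // Unreachable for the combinations above; NAN is the error sentinel,
    // as in the patched function.
    assert(!"unsupported alg_kind / prop_kind combination");
    return NAN;
}

int main() {
    // Leaky-ReLU forward: negative input scaled by alpha.
    assert(activation(alg_kind_t::eltwise_relu, prop_kind_t::forward,
                   -2.f, 0.25f)
            == -0.5f);
    // tanh backward evaluated on the saved forward output: 1 - s^2.
    assert(activation(alg_kind_t::eltwise_tanh, prop_kind_t::backward,
                   0.5f, 0.f)
            == 0.75f);
    return 0;
}

The trade-off the commit makes is the usual one: the templated versions let the compiler inline each activation into its caller, while the runtime switch keeps one code path per postgemm kernel and drops the per-specialization boilerplate.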