Commit e92c404

cpu: rnn: make activation function non-templated

Parent: 6f5621a

2 files changed: +31 −57 lines

src/cpu/rnn/postgemm_dispatcher.hpp

Lines changed: 3 additions & 20 deletions
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2019-2023 Intel Corporation
+* Copyright 2019-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -45,8 +45,8 @@ namespace dnnl {
 namespace impl {
 namespace cpu {
 
-template <alg_kind_t alg_kind, prop_kind_t prop_kind>
-float activation(float s, float alpha, float cliping);
+float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
+        float alpha, float cliping);
 
 template <prop_kind_t aprop, impl::data_type_t src_type,
         impl::data_type_t scratch_type, impl::data_type_t acc_type>
@@ -88,22 +88,6 @@ struct rnn_postgemm_dispatcher {
             break;
         case alg_kind::vanilla_rnn:
             postgemm_func = &class_name::rnn_postgemm;
-            switch (pd->activation_kind()) {
-                case alg_kind::eltwise_relu:
-                    activation_func
-                            = &activation<alg_kind::eltwise_relu, aprop>;
-                    break;
-                case alg_kind::eltwise_tanh:
-                    activation_func
-                            = &activation<alg_kind::eltwise_tanh, aprop>;
-                    break;
-                case alg_kind::eltwise_logistic:
-                    activation_func
-                            = &activation<alg_kind::eltwise_logistic,
-                                    aprop>;
-                    break;
-                default: assert(!"Unsupported activation function"); break;
-            }
             break;
         case alg_kind::vanilla_gru:
         case alg_kind::vanilla_augru:
@@ -233,7 +217,6 @@ struct rnn_postgemm_dispatcher {
     }
 
 protected:
-    float (*activation_func)(float s, float alpha, float cliping);
     virtual rnn_postgemm_sig(rnn_postgemm) = 0;
     virtual rnn_postgemm_sig(lstm_postgemm) = 0;
     virtual rnn_postgemm_sig(lstm_projection_postgemm) = 0;
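Taken together, the header now declares a single free function and the dispatcher drops its cached activation_func pointer: the activation is no longer selected once at dispatcher-construction time but resolved at call time. A minimal, self-contained sketch of the before/after dispatch styles, using hypothetical stand-in enums rather than oneDNN's real alg_kind_t/prop_kind_t:

#include <cassert>
#include <cmath>

enum class alg { relu, tanh_fn };       // stand-in for alg_kind_t
enum class prop { forward, backward };  // stand-in for prop_kind_t

// Old style: one instantiation per (alg, prop) pair; the dispatcher picks
// a specialization up front and caches it in a function pointer.
template <alg A, prop P>
float activation_tpl(float s, float alpha);

template <>
float activation_tpl<alg::relu, prop::forward>(float s, float alpha) {
    return s > 0 ? s : alpha * s;
}
template <>
float activation_tpl<alg::tanh_fn, prop::forward>(float s, float) {
    return std::tanh(s);
}

// New style: a single non-templated function that branches on its runtime
// arguments, so there is no per-pair instantiation and nothing to cache.
float activation(alg a, prop p, float s, float alpha) {
    if (p == prop::forward) {
        switch (a) {
            case alg::relu: return s > 0 ? s : alpha * s;
            case alg::tanh_fn: return std::tanh(s);
        }
    }
    assert(!"unsupported combination");
    return NAN;
}

int main() {
    float (*cached)(float, float) = &activation_tpl<alg::relu, prop::forward>;
    float a = cached(-1.f, 0.1f);                                // old: via cached pointer
    float b = activation(alg::relu, prop::forward, -1.f, 0.1f); // new: direct call
    return (a == b) ? 0 : 1;
}

Collapsing the specializations this way trades the per-pair template instantiations for a small runtime branch per call; the commit message does not state a motivation, so that reading of the trade-off is an inference.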

src/cpu/rnn/ref_postgemm_rnn.cpp

Lines changed: 28 additions & 37 deletions
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2018-2023 Intel Corporation
+* Copyright 2018-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -31,40 +31,29 @@ using namespace dnnl::impl::utils;
 using namespace dnnl::impl::math;
 using namespace rnn_utils;
 
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return relu_fwd<float>(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return relu_bwd<float>(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return tanh_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return one_m_square<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
-        float s, float alpha, float cliping) {
-    return logistic_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
-        float s, float alpha, float cliping) {
-    return x_m_square<float>(s);
+float activation(alg_kind_t alg_kind, prop_kind_t prop_kind, float s,
+        float alpha, float cliping) {
+    using namespace dnnl::impl::alg_kind;
+
+    if (prop_kind == prop_kind::forward
+            || prop_kind == prop_kind::forward_inference) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_fwd<float>(s, alpha);
+            case eltwise_tanh: return tanh_fwd<float>(s);
+            case eltwise_logistic: return logistic_fwd<float>(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else if (prop_kind == prop_kind::backward) {
+        switch (alg_kind) {
+            case eltwise_relu: return relu_bwd<float>(s, alpha);
+            case eltwise_tanh: return one_m_square<float>(s);
+            case eltwise_logistic: return x_m_square<float>(s);
+            default: assert(!"unsupported algorithm");
+        }
+    } else {
+        assert(!"unsupported propagation kind");
+    }
+    return NAN;
 }
 
 constexpr float linear(float s, float alpha, float clipping) {
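One detail worth noting in the consolidated function: the backward branch computes each derivative from the already-computed forward output s rather than from the original input, which is why the tanh and logistic cases call one_m_square and x_m_square. A sketch of the assumed plain-float semantics of those helpers (not the dnnl::impl::math definitions verbatim):

#include <cmath>

// Assumed semantics of the helpers used in the backward branch: each is the
// elementwise derivative expressed in terms of the forward output s.
float one_m_square(float s) { return 1.f - s * s; } // tanh'(x) = 1 - tanh(x)^2, with s = tanh(x)
float x_m_square(float s) { return s - s * s; }     // sigma'(x) = sigma(x)(1 - sigma(x)), with s = sigma(x)

int main() {
    float x = 0.5f;
    float s = std::tanh(x);
    // The derivative recovered from the forward output matches 1 / cosh(x)^2.
    float d1 = one_m_square(s);
    float d2 = 1.f / (std::cosh(x) * std::cosh(x));
    return std::fabs(d1 - d2) < 1e-6f ? 0 : 1;
}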
@@ -118,7 +107,8 @@ rnn_postgemm_sig(
         (rnn_postgemm_fwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return gates_t(this->activation_func(a, alpha, clipping));
+        return gates_t(activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, clipping));
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return gates_t(linear(a, alpha, clipping));
@@ -178,7 +168,8 @@ rnn_postgemm_sig(
         (rnn_postgemm_bwd_t<src_type, scratch_type, acc_type>::rnn_postgemm)) {
     const float *scales = this->pd_->attr()->rnn_tparams_.scales_;
     const auto act_f = [this](float a, float alpha, float clipping) {
-        return this->activation_func(a, alpha, 0);
+        return activation(this->pd_->activation_kind(),
+                this->pd_->get_prop_kind(), a, alpha, 0);
     };
     const auto linear_f = [](float a, float alpha, float clipping) {
         return linear(a, alpha, 0);
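The forward and backward call sites change identically: the act_f lambda no longer goes through the cached this->activation_func pointer but forwards the runtime kinds read from the primitive descriptor on each call. A stripped-down sketch of that call-site shape, with hypothetical stand-ins for the descriptor and the element loop (only the lambda mirrors the diff):

enum class alg { relu };
enum class prop { forward };

// Stand-in for the free activation() function introduced by this commit.
float activation(alg, prop, float s, float alpha, float /*clipping*/) {
    return s > 0 ? s : alpha * s; // relu forward only, for brevity
}

// Stand-in for the primitive descriptor the postgemm reads its kinds from.
struct pd_t {
    alg activation_kind() const { return alg::relu; }
    prop get_prop_kind() const { return prop::forward; }
};

struct postgemm_t {
    const pd_t *pd_;
    void run(float *gates, int n, float alpha, float clipping) const {
        // As in the diff: the kinds are fetched from pd_ inside the lambda,
        // so dispatch happens per invocation instead of once at construction.
        const auto act_f = [this](float a, float alpha, float clipping) {
            return activation(this->pd_->activation_kind(),
                    this->pd_->get_prop_kind(), a, alpha, clipping);
        };
        for (int i = 0; i < n; ++i)
            gates[i] = act_f(gates[i], alpha, clipping);
    }
};

int main() {
    pd_t pd;
    float gates[3] = {-1.f, 0.f, 2.f};
    postgemm_t{&pd}.run(gates, 3, 0.1f, 0.f);
    return gates[2] == 2.f ? 0 : 1;
}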
