@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15+ #pragma once
16+ #include < type_traits>
1517#include " paddle/fluid/operators/math/detail/activation_functions.h"
1618#include " paddle/fluid/platform/hostdevice.h"
1719
18- #include < type_traits>
19-
2020// TODO(guosheng): refine code style in gru_kernel
2121namespace paddle {
2222namespace operators {
@@ -28,25 +28,25 @@ namespace forward {
2828template <typename T>
2929class gru_resetOutput {
3030 public:
31- HOSTDEVICE void operator ()(T & value_update_gate, T & value_reset_gate,
32- T & prev_out, T & value_reset_output,
31+ HOSTDEVICE void operator ()(T * value_update_gate, T * value_reset_gate,
32+ T * prev_out, T * value_reset_output,
3333 ActivationType act_gate) {
34- value_update_gate = activation (value_update_gate, act_gate);
35- value_reset_gate = activation (value_reset_gate, act_gate);
36- value_reset_output = prev_out * value_reset_gate;
34+ * value_update_gate = activation (* value_update_gate, act_gate);
35+ * value_reset_gate = activation (* value_reset_gate, act_gate);
36+ * value_reset_output = (* prev_out) * (* value_reset_gate) ;
3737 }
3838#ifndef __NVCC__
3939#ifndef __AVX__
4040 static const bool avx = false ;
4141#else
4242 static const bool avx = true ;
43- HOSTDEVICE void operator ()(__m256 & value_update_gate,
44- __m256 & value_reset_gate, __m256 & prev_out,
45- __m256 & value_reset_output,
43+ HOSTDEVICE void operator ()(__m256 * value_update_gate,
44+ __m256 * value_reset_gate, __m256 * prev_out,
45+ __m256 * value_reset_output,
4646 ActivationType act_gate) {
47- value_update_gate = activation (value_update_gate, act_gate);
48- value_reset_gate = activation (value_reset_gate, act_gate);
49- value_reset_output = _mm256_mul_ps (prev_out, value_reset_gate);
47+ * value_update_gate = activation (* value_update_gate, act_gate);
48+ * value_reset_gate = activation (* value_reset_gate, act_gate);
49+ * value_reset_output = _mm256_mul_ps (* prev_out, * value_reset_gate);
5050 }
5151#endif
5252#endif
@@ -55,25 +55,25 @@ class gru_resetOutput {
5555template <typename T>
5656class gru_finalOutput {
5757 public:
58- HOSTDEVICE void operator ()(T & value_update_gate, T & value_frame_state,
59- T & prev_out, T & value_output,
58+ HOSTDEVICE void operator ()(T * value_update_gate, T * value_frame_state,
59+ T * prev_out, T * value_output,
6060 ActivationType act_input) {
61- value_frame_state = activation (value_frame_state, act_input);
62- value_output = prev_out - (value_update_gate * prev_out) +
63- ( value_update_gate * value_frame_state);
61+ * value_frame_state = activation (* value_frame_state, act_input);
62+ * value_output = * prev_out - ((* value_update_gate) * (* prev_out) ) +
63+ ((* value_update_gate) * (* value_frame_state) );
6464 }
6565#ifndef __NVCC__
6666#ifndef __AVX__
6767 static const bool avx = false ;
6868#else
6969 static const bool avx = true ;
70- HOSTDEVICE void operator ()(__m256 & value_update_gate,
71- __m256 & value_frame_state, __m256 & prev_out,
72- __m256 & value_output, ActivationType act_input) {
73- value_frame_state = activation (value_frame_state, act_input);
74- value_output = _mm256_add_ps (
75- _mm256_sub_ps (prev_out, _mm256_mul_ps (value_update_gate, prev_out)),
76- _mm256_mul_ps (value_update_gate, value_frame_state));
70+ HOSTDEVICE void operator ()(__m256 * value_update_gate,
71+ __m256 * value_frame_state, __m256 * prev_out,
72+ __m256 * value_output, ActivationType act_input) {
73+ * value_frame_state = activation (* value_frame_state, act_input);
74+ * value_output = _mm256_add_ps (
75+ _mm256_sub_ps (* prev_out, _mm256_mul_ps (* value_update_gate, * prev_out)),
76+ _mm256_mul_ps (* value_update_gate, * value_frame_state));
7777 }
7878#endif
7979#endif
@@ -85,37 +85,38 @@ namespace backward {
8585template <typename T>
8686class gru_stateGrad {
8787 public:
88- HOSTDEVICE void operator ()(T & value_update_gate, T & grad_update_gate,
89- T & value_frame_state, T & grad_frame_state,
90- T & value_prev_out, T & grad_prev_out,
91- T & grad_output, ActivationType act_input) {
92- grad_update_gate = (grad_output * value_frame_state);
93- grad_update_gate -= (grad_output * value_prev_out);
94- grad_prev_out -= (grad_output * value_update_gate);
95- grad_prev_out += grad_output;
96- grad_frame_state = activation (grad_output * value_update_gate,
97- value_frame_state, act_input);
88+ HOSTDEVICE void operator ()(T * value_update_gate, T * grad_update_gate,
89+ T * value_frame_state, T * grad_frame_state,
90+ T * value_prev_out, T * grad_prev_out,
91+ T * grad_output, ActivationType act_input) {
92+ * grad_update_gate = (* grad_output * (* value_frame_state) );
93+ * grad_update_gate -= (* grad_output * (* value_prev_out) );
94+ * grad_prev_out -= (* grad_output * (* value_update_gate) );
95+ * grad_prev_out += * grad_output;
96+ * grad_frame_state = activation (* grad_output * (* value_update_gate) ,
97+ * value_frame_state, act_input);
9898 }
9999#ifndef __NVCC__
100100#ifndef __AVX__
101101 static const bool avx = false ;
102102#else
103103 static const bool avx = true ;
104- HOSTDEVICE void operator ()(__m256 & value_update_gate,
105- __m256 & grad_update_gate,
106- __m256 & value_frame_state,
107- __m256 & grad_frame_state, __m256 & value_prev_out,
108- __m256 & grad_prev_out, __m256 & grad_output,
104+ HOSTDEVICE void operator ()(__m256 * value_update_gate,
105+ __m256 * grad_update_gate,
106+ __m256 * value_frame_state,
107+ __m256 * grad_frame_state, __m256 * value_prev_out,
108+ __m256 * grad_prev_out, __m256 * grad_output,
109109 ActivationType act_input) {
110- grad_update_gate = _mm256_mul_ps (grad_output, value_frame_state);
111- grad_update_gate = _mm256_sub_ps (
112- grad_update_gate, _mm256_mul_ps (grad_output, value_prev_out));
113- grad_prev_out = _mm256_add_ps (
114- _mm256_sub_ps (grad_prev_out,
115- _mm256_mul_ps (grad_output, value_update_gate)),
116- grad_output);
117- grad_frame_state = activation (_mm256_mul_ps (grad_output, value_update_gate),
118- value_frame_state, act_input);
110+ *grad_update_gate = _mm256_mul_ps (*grad_output, *value_frame_state);
111+ *grad_update_gate = _mm256_sub_ps (
112+ *grad_update_gate, _mm256_mul_ps (*grad_output, *value_prev_out));
113+ *grad_prev_out = _mm256_add_ps (
114+ _mm256_sub_ps (*grad_prev_out,
115+ _mm256_mul_ps (*grad_output, *value_update_gate)),
116+ *grad_output);
117+ *grad_frame_state =
118+ activation (_mm256_mul_ps (*grad_output, *value_update_gate),
119+ *value_frame_state, act_input);
119120 }
120121#endif
121122#endif
@@ -124,32 +125,34 @@ class gru_stateGrad {
124125template <typename T>
125126class gru_resetGrad {
126127 public:
127- HOSTDEVICE void operator ()(T &value_update_gate, T &grad_update_gate,
128- T &value_reset_gate, T &grad_reset_gate,
129- T &value_prev_out, T &grad_prev_out,
130- T &grad_reset_output, ActivationType act_gate) {
131- grad_reset_gate = (grad_reset_output * value_prev_out);
132- grad_prev_out += (grad_reset_output * value_reset_gate);
133- grad_update_gate =
134- activation (grad_update_gate, value_update_gate, act_gate);
135- grad_reset_gate = activation (grad_reset_gate, value_reset_gate, act_gate);
128+ HOSTDEVICE void operator ()(T *value_update_gate, T *grad_update_gate,
129+ T *value_reset_gate, T *grad_reset_gate,
130+ T *value_prev_out, T *grad_prev_out,
131+ T *grad_reset_output, ActivationType act_gate) {
132+ *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
133+ *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
134+ *grad_update_gate =
135+ activation (*grad_update_gate, *value_update_gate, act_gate);
136+ *grad_reset_gate =
137+ activation (*grad_reset_gate, *value_reset_gate, act_gate);
136138 }
137139#ifndef __NVCC__
138140#ifndef __AVX__
139141 static const bool avx = false ;
140142#else
141143 static const bool avx = true ;
142- HOSTDEVICE void operator ()(__m256 & value_update_gate,
143- __m256 & grad_update_gate, __m256 & value_reset_gate,
144- __m256 & grad_reset_gate, __m256 & value_prev_out,
145- __m256 & grad_prev_out, __m256 & grad_reset_output,
144+ HOSTDEVICE void operator ()(__m256 * value_update_gate,
145+ __m256 * grad_update_gate, __m256 * value_reset_gate,
146+ __m256 * grad_reset_gate, __m256 * value_prev_out,
147+ __m256 * grad_prev_out, __m256 * grad_reset_output,
146148 ActivationType act_gate) {
147- grad_reset_gate = _mm256_mul_ps (grad_reset_output, value_prev_out);
148- grad_prev_out = _mm256_add_ps (
149- grad_prev_out, _mm256_mul_ps (grad_reset_output, value_reset_gate));
150- grad_update_gate =
151- activation (grad_update_gate, value_update_gate, act_gate);
152- grad_reset_gate = activation (grad_reset_gate, value_reset_gate, act_gate);
149+ *grad_reset_gate = _mm256_mul_ps (*grad_reset_output, *value_prev_out);
150+ *grad_prev_out = _mm256_add_ps (
151+ *grad_prev_out, _mm256_mul_ps (*grad_reset_output, *value_reset_gate));
152+ *grad_update_gate =
153+ activation (*grad_update_gate, *value_update_gate, act_gate);
154+ *grad_reset_gate =
155+ activation (*grad_reset_gate, *value_reset_gate, act_gate);
153156 }
154157#endif
155158#endif
0 commit comments