Skip to content

Commit c9f55df

Browse files
authored
Fix CPPLint issues in /math/detail/gru_kernel.h (#10390)
* Fix CPPLint issues in gru_kernel.h * Fix CPPLint issues in gru_kernel.h * Fix Compile error
1 parent 20fa848 commit c9f55df

File tree

3 files changed

+101
-98
lines changed

3 files changed

+101
-98
lines changed

paddle/fluid/operators/math/detail/gru_cpu_kernel.h

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
4343
r_prev_out = prev_output_value[i];
4444
}
4545

46-
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
47-
r_value_reset_output, active_gate);
46+
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
47+
&r_value_reset_output, active_gate);
4848

4949
update_gate[i] = r_value_update_gate;
5050
reset_gate[i] = r_value_reset_gate;
@@ -71,8 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
7171
r_prev_out = prev_output_value[i];
7272
}
7373

74-
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
75-
r_output, active_node);
74+
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
75+
&r_output, active_node);
7676

7777
frame_state[i] = r_value_frame_state;
7878
output_value[i] = r_output;
@@ -99,8 +99,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
9999
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
100100
}
101101

102-
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
103-
r_value_reset_output, active_gate);
102+
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
103+
&r_value_reset_output, active_gate);
104104

105105
update_gate[i] = r_value_update_gate;
106106
reset_gate[i] = r_value_reset_gate;
@@ -129,8 +129,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
129129
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
130130
}
131131

132-
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
133-
r_output, active_node);
132+
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
133+
&r_output, active_node);
134134

135135
frame_state[i] = r_value_frame_state;
136136
(reinterpret_cast<__m256 *>(output_value))[i] = r_output;
@@ -213,9 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
213213
r_prev_out_grad = prev_out_grad[i];
214214
}
215215

216-
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
217-
r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
218-
r_out_grad, active_node);
216+
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
217+
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
218+
&r_prev_out_grad, &r_out_grad, active_node);
219219

220220
update_gate_grad[i] = r_update_gate_grad;
221221
frame_state_grad[i] = r_frame_state_grad;
@@ -258,9 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
258258
r_prev_out_grad = prev_out_grad[i];
259259
}
260260

261-
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
262-
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
263-
r_reset_output_grad, active_gate);
261+
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
262+
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
263+
&r_prev_out_grad, &r_reset_output_grad, active_gate);
264264

265265
update_gate_grad[i] = r_update_gate_grad;
266266
reset_gate_grad[i] = r_reset_gate_grad;
@@ -302,9 +302,9 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
302302
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
303303
}
304304

305-
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
306-
r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
307-
r_out_grad, active_node);
305+
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
306+
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
307+
&r_prev_out_grad, &r_out_grad, active_node);
308308

309309
update_gate_grad[i] = r_update_gate_grad;
310310
frame_state_grad[i] = r_frame_state_grad;
@@ -350,9 +350,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
350350
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
351351
}
352352

353-
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
354-
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
355-
r_reset_output_grad, active_gate);
353+
op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
354+
&r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
355+
&r_prev_out_grad, &r_reset_output_grad, active_gate);
356356

357357
update_gate_grad[i] = r_update_gate_grad;
358358
reset_gate_grad[i] = r_reset_gate_grad;

paddle/fluid/operators/math/detail/gru_gpu_kernel.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
5555
r_prev_out = prev_output_value[frame_idx];
5656
}
5757

58-
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out,
59-
r_value_reset_output, active_gate);
58+
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
59+
&r_value_reset_output, active_gate);
6060

6161
gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
6262
gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
@@ -93,8 +93,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
9393
r_prev_out = prev_output_value[frame_idx];
9494
}
9595

96-
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out,
97-
r_output, active_node);
96+
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
97+
&r_output, active_node);
9898

9999
gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
100100
output_value[frame_idx] = r_output;
@@ -137,9 +137,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
137137
r_prev_out_grad = prev_out_grad[frame_idx];
138138
}
139139

140-
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value,
141-
r_frame_state_grad, r_prev_out_value, r_prev_out_grad,
142-
r_out_grad, active_node);
140+
op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
141+
&r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
142+
&r_out_grad, active_node);
143143

144144
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
145145
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
@@ -185,9 +185,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
185185
r_reset_output_grad = reset_output_grad[frame_idx];
186186
}
187187

188-
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value,
189-
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad,
190-
r_reset_output_grad, active_gate);
188+
op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
189+
&r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
190+
&r_reset_output_grad, active_gate);
191191

192192
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
193193
gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;

paddle/fluid/operators/math/detail/gru_kernel.h

Lines changed: 71 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#pragma once
16+
#include <type_traits>
1517
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1618
#include "paddle/fluid/platform/hostdevice.h"
1719

18-
#include <type_traits>
19-
2020
// TODO(guosheng): refine code style in gru_kernel
2121
namespace paddle {
2222
namespace operators {
@@ -28,25 +28,25 @@ namespace forward {
2828
template <typename T>
2929
class gru_resetOutput {
3030
public:
31-
HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
32-
T &prev_out, T &value_reset_output,
31+
HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate,
32+
T *prev_out, T *value_reset_output,
3333
ActivationType act_gate) {
34-
value_update_gate = activation(value_update_gate, act_gate);
35-
value_reset_gate = activation(value_reset_gate, act_gate);
36-
value_reset_output = prev_out * value_reset_gate;
34+
*value_update_gate = activation(*value_update_gate, act_gate);
35+
*value_reset_gate = activation(*value_reset_gate, act_gate);
36+
*value_reset_output = (*prev_out) * (*value_reset_gate);
3737
}
3838
#ifndef __NVCC__
3939
#ifndef __AVX__
4040
static const bool avx = false;
4141
#else
4242
static const bool avx = true;
43-
HOSTDEVICE void operator()(__m256 &value_update_gate,
44-
__m256 &value_reset_gate, __m256 &prev_out,
45-
__m256 &value_reset_output,
43+
HOSTDEVICE void operator()(__m256 *value_update_gate,
44+
__m256 *value_reset_gate, __m256 *prev_out,
45+
__m256 *value_reset_output,
4646
ActivationType act_gate) {
47-
value_update_gate = activation(value_update_gate, act_gate);
48-
value_reset_gate = activation(value_reset_gate, act_gate);
49-
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
47+
*value_update_gate = activation(*value_update_gate, act_gate);
48+
*value_reset_gate = activation(*value_reset_gate, act_gate);
49+
*value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
5050
}
5151
#endif
5252
#endif
@@ -55,25 +55,25 @@ class gru_resetOutput {
5555
template <typename T>
5656
class gru_finalOutput {
5757
public:
58-
HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
59-
T &prev_out, T &value_output,
58+
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
59+
T *prev_out, T *value_output,
6060
ActivationType act_input) {
61-
value_frame_state = activation(value_frame_state, act_input);
62-
value_output = prev_out - (value_update_gate * prev_out) +
63-
(value_update_gate * value_frame_state);
61+
*value_frame_state = activation(*value_frame_state, act_input);
62+
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
63+
((*value_update_gate) * (*value_frame_state));
6464
}
6565
#ifndef __NVCC__
6666
#ifndef __AVX__
6767
static const bool avx = false;
6868
#else
6969
static const bool avx = true;
70-
HOSTDEVICE void operator()(__m256 &value_update_gate,
71-
__m256 &value_frame_state, __m256 &prev_out,
72-
__m256 &value_output, ActivationType act_input) {
73-
value_frame_state = activation(value_frame_state, act_input);
74-
value_output = _mm256_add_ps(
75-
_mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
76-
_mm256_mul_ps(value_update_gate, value_frame_state));
70+
HOSTDEVICE void operator()(__m256 *value_update_gate,
71+
__m256 *value_frame_state, __m256 *prev_out,
72+
__m256 *value_output, ActivationType act_input) {
73+
*value_frame_state = activation(*value_frame_state, act_input);
74+
*value_output = _mm256_add_ps(
75+
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
76+
_mm256_mul_ps(*value_update_gate, *value_frame_state));
7777
}
7878
#endif
7979
#endif
@@ -85,37 +85,38 @@ namespace backward {
8585
template <typename T>
8686
class gru_stateGrad {
8787
public:
88-
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
89-
T &value_frame_state, T &grad_frame_state,
90-
T &value_prev_out, T &grad_prev_out,
91-
T &grad_output, ActivationType act_input) {
92-
grad_update_gate = (grad_output * value_frame_state);
93-
grad_update_gate -= (grad_output * value_prev_out);
94-
grad_prev_out -= (grad_output * value_update_gate);
95-
grad_prev_out += grad_output;
96-
grad_frame_state = activation(grad_output * value_update_gate,
97-
value_frame_state, act_input);
88+
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
89+
T *value_frame_state, T *grad_frame_state,
90+
T *value_prev_out, T *grad_prev_out,
91+
T *grad_output, ActivationType act_input) {
92+
*grad_update_gate = (*grad_output * (*value_frame_state));
93+
*grad_update_gate -= (*grad_output * (*value_prev_out));
94+
*grad_prev_out -= (*grad_output * (*value_update_gate));
95+
*grad_prev_out += *grad_output;
96+
*grad_frame_state = activation(*grad_output * (*value_update_gate),
97+
*value_frame_state, act_input);
9898
}
9999
#ifndef __NVCC__
100100
#ifndef __AVX__
101101
static const bool avx = false;
102102
#else
103103
static const bool avx = true;
104-
HOSTDEVICE void operator()(__m256 &value_update_gate,
105-
__m256 &grad_update_gate,
106-
__m256 &value_frame_state,
107-
__m256 &grad_frame_state, __m256 &value_prev_out,
108-
__m256 &grad_prev_out, __m256 &grad_output,
104+
HOSTDEVICE void operator()(__m256 *value_update_gate,
105+
__m256 *grad_update_gate,
106+
__m256 *value_frame_state,
107+
__m256 *grad_frame_state, __m256 *value_prev_out,
108+
__m256 *grad_prev_out, __m256 *grad_output,
109109
ActivationType act_input) {
110-
grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
111-
grad_update_gate = _mm256_sub_ps(
112-
grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
113-
grad_prev_out = _mm256_add_ps(
114-
_mm256_sub_ps(grad_prev_out,
115-
_mm256_mul_ps(grad_output, value_update_gate)),
116-
grad_output);
117-
grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate),
118-
value_frame_state, act_input);
110+
*grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
111+
*grad_update_gate = _mm256_sub_ps(
112+
*grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
113+
*grad_prev_out = _mm256_add_ps(
114+
_mm256_sub_ps(*grad_prev_out,
115+
_mm256_mul_ps(*grad_output, *value_update_gate)),
116+
*grad_output);
117+
*grad_frame_state =
118+
activation(_mm256_mul_ps(*grad_output, *value_update_gate),
119+
*value_frame_state, act_input);
119120
}
120121
#endif
121122
#endif
@@ -124,32 +125,34 @@ class gru_stateGrad {
124125
template <typename T>
125126
class gru_resetGrad {
126127
public:
127-
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
128-
T &value_reset_gate, T &grad_reset_gate,
129-
T &value_prev_out, T &grad_prev_out,
130-
T &grad_reset_output, ActivationType act_gate) {
131-
grad_reset_gate = (grad_reset_output * value_prev_out);
132-
grad_prev_out += (grad_reset_output * value_reset_gate);
133-
grad_update_gate =
134-
activation(grad_update_gate, value_update_gate, act_gate);
135-
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
128+
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
129+
T *value_reset_gate, T *grad_reset_gate,
130+
T *value_prev_out, T *grad_prev_out,
131+
T *grad_reset_output, ActivationType act_gate) {
132+
*grad_reset_gate = (*grad_reset_output * (*value_prev_out));
133+
*grad_prev_out += (*grad_reset_output * (*value_reset_gate));
134+
*grad_update_gate =
135+
activation(*grad_update_gate, *value_update_gate, act_gate);
136+
*grad_reset_gate =
137+
activation(*grad_reset_gate, *value_reset_gate, act_gate);
136138
}
137139
#ifndef __NVCC__
138140
#ifndef __AVX__
139141
static const bool avx = false;
140142
#else
141143
static const bool avx = true;
142-
HOSTDEVICE void operator()(__m256 &value_update_gate,
143-
__m256 &grad_update_gate, __m256 &value_reset_gate,
144-
__m256 &grad_reset_gate, __m256 &value_prev_out,
145-
__m256 &grad_prev_out, __m256 &grad_reset_output,
144+
HOSTDEVICE void operator()(__m256 *value_update_gate,
145+
__m256 *grad_update_gate, __m256 *value_reset_gate,
146+
__m256 *grad_reset_gate, __m256 *value_prev_out,
147+
__m256 *grad_prev_out, __m256 *grad_reset_output,
146148
ActivationType act_gate) {
147-
grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
148-
grad_prev_out = _mm256_add_ps(
149-
grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
150-
grad_update_gate =
151-
activation(grad_update_gate, value_update_gate, act_gate);
152-
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate);
149+
*grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
150+
*grad_prev_out = _mm256_add_ps(
151+
*grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
152+
*grad_update_gate =
153+
activation(*grad_update_gate, *value_update_gate, act_gate);
154+
*grad_reset_gate =
155+
activation(*grad_reset_gate, *value_reset_gate, act_gate);
153156
}
154157
#endif
155158
#endif

0 commit comments

Comments
 (0)