Skip to content

Commit 44355a6

Browse files
rjoursler authored and vpirogov committed
gpu: ocl: fix post-ops REPLICATE_DATA copy_size
1 parent e6b93af commit 44355a6

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/gpu/ocl/ocl_post_ops.h

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -219,8 +219,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
219219
#define REPLICATE_DATA( \
220220
dest_ptr, dest_size, x0_s, x1_s, x2_s, x3_s, x4_s, x5_s) \
221221
{ \
222-
const unsigned copy_size \
223-
= x0_s * X_NELEMS(x1_s) * x2_s * x3_s * x4_s * x5_s; \
222+
const unsigned copy_size = x0_s * x1_s * x2_s * x3_s * x4_s * x5_s; \
224223
unroll_for(unsigned fid = copy_size; fid < dest_size; ++fid) { \
225224
*(dest_ptr + fid) = *(dest_ptr + (fid % copy_size)); \
226225
} \
@@ -263,17 +262,20 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
263262
float *bin_arg_ptr = &bin_arg[0]; \
264263
const bool use_burst_read = IS_BURSTABLE(idx, x0, x0_s, x1, x1_s, x2, \
265264
x2_s, x3, x3_s, x4, x4_s, x5, x5_s, is_burst); \
266-
const unsigned x1_jump = is_burst ? SUB_GROUP_SIZE : 1; \
267265
if (use_burst_read) { \
268266
FILL_BIN_ARG_TRY_BLOCK(idx, bin_arg_ptr, bin_arg_size, x0, x0_s, \
269267
x1, x1_s, x1_incr, x2, x2_s, x3, x3_s, x4, x4_s, x5, \
270268
x5_s); \
269+
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, X_NELEMS(x1_s), \
270+
x2_s, x3_s, x4_s, x5_s); \
271271
} else { \
272+
const unsigned x1_jump = is_burst ? SUB_GROUP_SIZE : 1; \
273+
const unsigned x1_size = x1_s / x1_jump; \
272274
FILL_BIN_ARG_SERIAL(idx, bin_arg_ptr, x0, x0_s, (x1 + x1_incr), \
273275
x1_s, x1_jump, x2, x2_s, x3, x3_s, x4, x4_s, x5, x5_s); \
276+
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
277+
x3_s, x4_s, x5_s); \
274278
} \
275-
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_s, x2_s, x3_s, \
276-
x4_s, x5_s); \
277279
FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
278280
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
279281
(sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \

0 commit comments

Comments
 (0)