Skip to content

Commit c8943f5

Browse files
dyoussifkarturov
authored andcommitted
Revert "gpu: ocl: remove unecessary post_op kind"
This reverts commit 00d1bac.
1 parent ad3c62f commit c8943f5

File tree

1 file changed

+28
-21
lines changed

1 file changed

+28
-21
lines changed

src/gpu/ocl/ocl_post_ops.h

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,30 @@
2727
#include "gpu/ocl/ocl_eltwise.h"
2828
#include "gpu/ocl/ocl_types.h"
2929

30-
float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
31-
float scale) {
32-
switch (algorithm) {
33-
// binary
34-
case BINARY_ADD: return x + y; break;
35-
case BINARY_MUL: return x * y; break;
36-
case BINARY_MIN: return x < y ? x : y; break;
37-
case BINARY_MAX: return x > y ? x : y; break;
38-
case BINARY_DIV: return x / y; break;
39-
case BINARY_SUB: return x - y; break;
40-
case BINARY_GE: return x >= y; break;
41-
case BINARY_GT: return x > y; break;
42-
case BINARY_LE: return x <= y; break;
43-
case BINARY_LT: return x < y; break;
44-
case BINARY_EQ: return x == y; break;
45-
case BINARY_NE: return x != y; break;
46-
default: return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
30+
float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
31+
float alpha, float beta, float scale) {
32+
if (kind == PO_BINARY) {
33+
switch (algorithm) {
34+
// binary
35+
case BINARY_ADD: return x + y; break;
36+
case BINARY_MUL: return x * y; break;
37+
case BINARY_MIN: return x < y ? x : y; break;
38+
case BINARY_MAX: return x > y ? x : y; break;
39+
case BINARY_DIV: return x / y; break;
40+
case BINARY_SUB: return x - y; break;
41+
case BINARY_GE: return x >= y; break;
42+
case BINARY_GT: return x > y; break;
43+
case BINARY_LE: return x <= y; break;
44+
case BINARY_LT: return x < y; break;
45+
case BINARY_EQ: return x == y; break;
46+
case BINARY_NE: return x != y; break;
47+
case RELU: // binary && relu = prelu
48+
return fwd_eltwise_common(RELU, x, y, beta, scale);
49+
break;
50+
default: return 0.f;
51+
}
52+
} else { // eltwise kind
53+
return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
4754
}
4855
}
4956

@@ -58,8 +65,8 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
5865
ret_val; \
5966
})
6067

61-
#define FWD_XNARY_GENERIC_DT(algorithm, result, result_elem_dt, arg0_ptr, \
62-
arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
68+
#define FWD_XNARY_GENERIC_DT(po_kind, algorithm, result, result_elem_dt, \
69+
arg0_ptr, arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
6370
{ \
6471
auto ty = arg0_len + arg1_len; \
6572
const typeof(ty) out_len \
@@ -258,7 +265,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
258265
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
259266
x3_s, x4_s, x5_s); \
260267
} \
261-
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
268+
FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
262269
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
263270
(sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
264271
bin_arg_size, 0.0f, 0.0f, 1.0f); \
@@ -273,7 +280,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
273280

274281
#define APPLY_PO_ELTWISE(idx, accumulator, acc_elem_dt) \
275282
{ \
276-
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
283+
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
277284
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
278285
(sizeof(accumulator) / sizeof(acc_elem_dt)), \
279286
((acc_elem_dt *)(&accumulator)), \

0 commit comments

Comments
 (0)