Skip to content

Commit 87fd48f

Browse files
rjourslervpirogov
authored andcommitted
gpu: ocl: remove unecessary post_op kind
The binary and eltwise enums are guaranteed to not overlap via the API
1 parent 9a66ac6 commit 87fd48f

File tree

1 file changed

+28
-37
lines changed

1 file changed

+28
-37
lines changed

src/gpu/ocl/ocl_post_ops.h

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,23 @@
2727
#include "gpu/ocl/ocl_eltwise.h"
2828
#include "gpu/ocl/ocl_types.h"
2929

30-
float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
31-
float alpha, float beta, float scale) {
32-
if (kind == PO_BINARY) {
33-
switch (algorithm) {
34-
// binary
35-
case BINARY_ADD: return x + y; break;
36-
case BINARY_MUL: return x * y; break;
37-
case BINARY_MIN: return x < y ? x : y; break;
38-
case BINARY_MAX: return x > y ? x : y; break;
39-
case BINARY_DIV: return x / y; break;
40-
case BINARY_SUB: return x - y; break;
41-
case BINARY_GE: return x >= y; break;
42-
case BINARY_GT: return x > y; break;
43-
case BINARY_LE: return x <= y; break;
44-
case BINARY_LT: return x < y; break;
45-
case BINARY_EQ: return x == y; break;
46-
case BINARY_NE: return x != y; break;
47-
case RELU: // binary && relu = prelu
48-
return fwd_eltwise_common(RELU, x, y, beta, scale);
49-
break;
50-
default: return 0.f;
51-
}
52-
} else { // eltwise kind
53-
return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
30+
float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
31+
float scale) {
32+
switch (algorithm) {
33+
// binary
34+
case BINARY_ADD: return x + y; break;
35+
case BINARY_MUL: return x * y; break;
36+
case BINARY_MIN: return x < y ? x : y; break;
37+
case BINARY_MAX: return x > y ? x : y; break;
38+
case BINARY_DIV: return x / y; break;
39+
case BINARY_SUB: return x - y; break;
40+
case BINARY_GE: return x >= y; break;
41+
case BINARY_GT: return x > y; break;
42+
case BINARY_LE: return x <= y; break;
43+
case BINARY_LT: return x < y; break;
44+
case BINARY_EQ: return x == y; break;
45+
case BINARY_NE: return x != y; break;
46+
default: return fwd_eltwise_common(algorithm, x, alpha, beta, scale);
5447
}
5548
}
5649

@@ -65,28 +58,26 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
6558
ret_val; \
6659
})
6760

68-
#define FWD_XNARY_GENERIC_DT(po_kind, algorithm, result, result_elem_dt, \
69-
arg0_ptr, arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
61+
#define FWD_XNARY_GENERIC_DT(algorithm, result, result_elem_dt, arg0_ptr, \
62+
arg0_len, arg1_ptr, arg1_len, alpha, beta, scale) \
7063
{ \
7164
auto ty = arg0_len + arg1_len; \
7265
const typeof(ty) out_len \
7366
= max((typeof(ty))arg0_len, (typeof(ty))arg1_len); \
7467
result_elem_dt *res_ptr = (result_elem_dt *)(&result); \
7568
unroll_for(typeof(out_len + 0) idx = 0; idx < out_len; ++idx) { \
7669
if (arg0_len == 1 && arg1_len == 1) { \
77-
*res_ptr = fwd_Xnary(po_kind, algorithm, \
78-
convert_float(*arg0_ptr), convert_float(*arg1_ptr), \
79-
alpha, beta, scale); \
70+
*res_ptr = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
71+
convert_float(*arg1_ptr), alpha, beta, scale); \
8072
} else if (arg0_len == 1) { \
81-
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
82-
convert_float(*arg0_ptr), \
73+
res_ptr[idx] = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
8374
convert_float(arg1_ptr[idx]), alpha, beta, scale); \
8475
} else if (arg1_len == 1) { \
85-
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
86-
convert_float(arg0_ptr[idx]), \
87-
convert_float(*arg1_ptr), alpha, beta, scale); \
76+
res_ptr[idx] \
77+
= fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \
78+
convert_float(*arg1_ptr), alpha, beta, scale); \
8879
} else { \
89-
res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
80+
res_ptr[idx] = fwd_Xnary(algorithm, \
9081
convert_float(arg0_ptr[idx]), \
9182
convert_float(arg1_ptr[idx]), alpha, beta, scale); \
9283
} \
@@ -277,7 +268,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
277268
REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
278269
x3_s, x4_s, x5_s); \
279270
} \
280-
FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
271+
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
281272
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
282273
(sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
283274
bin_arg_size, 0.0f, 0.0f, 1.0f); \
@@ -292,7 +283,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
292283

293284
#define APPLY_PO_ELTWISE(idx, accumulator, acc_elem_dt) \
294285
{ \
295-
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
286+
FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
296287
acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
297288
(sizeof(accumulator) / sizeof(acc_elem_dt)), \
298289
((acc_elem_dt *)(&accumulator)), \

0 commit comments

Comments
 (0)