2727#include "gpu/ocl/ocl_eltwise.h"
2828#include "gpu/ocl/ocl_types.h"
2929
30- float fwd_Xnary (unsigned kind , unsigned algorithm , float x , float y ,
31- float alpha , float beta , float scale ) {
32- if (kind == PO_BINARY ) {
33- switch (algorithm ) {
34- // binary
35- case BINARY_ADD : return x + y ; break ;
36- case BINARY_MUL : return x * y ; break ;
37- case BINARY_MIN : return x < y ? x : y ; break ;
38- case BINARY_MAX : return x > y ? x : y ; break ;
39- case BINARY_DIV : return x / y ; break ;
40- case BINARY_SUB : return x - y ; break ;
41- case BINARY_GE : return x >= y ; break ;
42- case BINARY_GT : return x > y ; break ;
43- case BINARY_LE : return x <= y ; break ;
44- case BINARY_LT : return x < y ; break ;
45- case BINARY_EQ : return x == y ; break ;
46- case BINARY_NE : return x != y ; break ;
47- case RELU : // binary && relu = prelu
48- return fwd_eltwise_common (RELU , x , y , beta , scale );
49- break ;
50- default : return 0.f ;
51- }
52- } else { // eltwise kind
53- return fwd_eltwise_common (algorithm , x , alpha , beta , scale );
30+ float fwd_Xnary (unsigned algorithm , float x , float y , float alpha , float beta ,
31+ float scale ) {
32+ switch (algorithm ) {
33+ // binary
34+ case BINARY_ADD : return x + y ; break ;
35+ case BINARY_MUL : return x * y ; break ;
36+ case BINARY_MIN : return x < y ? x : y ; break ;
37+ case BINARY_MAX : return x > y ? x : y ; break ;
38+ case BINARY_DIV : return x / y ; break ;
39+ case BINARY_SUB : return x - y ; break ;
40+ case BINARY_GE : return x >= y ; break ;
41+ case BINARY_GT : return x > y ; break ;
42+ case BINARY_LE : return x <= y ; break ;
43+ case BINARY_LT : return x < y ; break ;
44+ case BINARY_EQ : return x == y ; break ;
45+ case BINARY_NE : return x != y ; break ;
46+ default : return fwd_eltwise_common (algorithm , x , alpha , beta , scale );
5447 }
5548}
5649
@@ -65,28 +58,26 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
6558 ret_val; \
6659 })
6760
68- #define FWD_XNARY_GENERIC_DT (po_kind , algorithm , result , result_elem_dt , \
69- arg0_ptr , arg0_len , arg1_ptr , arg1_len , alpha , beta , scale ) \
61+ #define FWD_XNARY_GENERIC_DT (algorithm , result , result_elem_dt , arg0_ptr , \
62+ arg0_len , arg1_ptr , arg1_len , alpha , beta , scale ) \
7063 { \
7164 auto ty = arg0_len + arg1_len; \
7265 const typeof(ty) out_len \
7366 = max((typeof(ty))arg0_len, (typeof(ty))arg1_len); \
7467 result_elem_dt *res_ptr = (result_elem_dt *)(&result); \
7568 unroll_for(typeof(out_len + 0) idx = 0; idx < out_len; ++idx) { \
7669 if (arg0_len == 1 && arg1_len == 1) { \
77- *res_ptr = fwd_Xnary(po_kind, algorithm, \
78- convert_float(*arg0_ptr), convert_float(*arg1_ptr), \
79- alpha, beta, scale); \
70+ *res_ptr = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
71+ convert_float(*arg1_ptr), alpha, beta, scale); \
8072 } else if (arg0_len == 1) { \
81- res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
82- convert_float(*arg0_ptr), \
73+ res_ptr[idx] = fwd_Xnary(algorithm, convert_float(*arg0_ptr), \
8374 convert_float(arg1_ptr[idx]), alpha, beta, scale); \
8475 } else if (arg1_len == 1) { \
85- res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
86- convert_float(arg0_ptr[idx]), \
87- convert_float(*arg1_ptr), alpha, beta, scale); \
76+ res_ptr[idx] \
77+ = fwd_Xnary(algorithm, convert_float(arg0_ptr[idx]), \
78+ convert_float(*arg1_ptr), alpha, beta, scale); \
8879 } else { \
89- res_ptr[idx] = fwd_Xnary(po_kind, algorithm, \
80+ res_ptr[idx] = fwd_Xnary(algorithm, \
9081 convert_float(arg0_ptr[idx]), \
9182 convert_float(arg1_ptr[idx]), alpha, beta, scale); \
9283 } \
@@ -277,7 +268,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
277268 REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
278269 x3_s, x4_s, x5_s); \
279270 } \
280- FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
271+ FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
281272 acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
282273 (sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
283274 bin_arg_size, 0.0f, 0.0f, 1.0f); \
@@ -292,7 +283,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
292283
293284#define APPLY_PO_ELTWISE (idx , accumulator , acc_elem_dt ) \
294285 { \
295- FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
286+ FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
296287 acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
297288 (sizeof(accumulator) / sizeof(acc_elem_dt)), \
298289 ((acc_elem_dt *)(&accumulator)), \
0 commit comments