2727#include "gpu/ocl/ocl_eltwise.h"
2828#include "gpu/ocl/ocl_types.h"
2929
30- float fwd_Xnary (unsigned algorithm , float x , float y , float alpha , float beta ,
31- float scale ) {
32- switch (algorithm ) {
33- // binary
34- case BINARY_ADD : return x + y ; break ;
35- case BINARY_MUL : return x * y ; break ;
36- case BINARY_MIN : return x < y ? x : y ; break ;
37- case BINARY_MAX : return x > y ? x : y ; break ;
38- case BINARY_DIV : return x / y ; break ;
39- case BINARY_SUB : return x - y ; break ;
40- case BINARY_GE : return x >= y ; break ;
41- case BINARY_GT : return x > y ; break ;
42- case BINARY_LE : return x <= y ; break ;
43- case BINARY_LT : return x < y ; break ;
44- case BINARY_EQ : return x == y ; break ;
45- case BINARY_NE : return x != y ; break ;
46- default : return fwd_eltwise_common (algorithm , x , alpha , beta , scale );
30+ float fwd_Xnary (unsigned kind , unsigned algorithm , float x , float y ,
31+ float alpha , float beta , float scale ) {
32+ if (kind == PO_BINARY ) {
33+ switch (algorithm ) {
34+ // binary
35+ case BINARY_ADD : return x + y ; break ;
36+ case BINARY_MUL : return x * y ; break ;
37+ case BINARY_MIN : return x < y ? x : y ; break ;
38+ case BINARY_MAX : return x > y ? x : y ; break ;
39+ case BINARY_DIV : return x / y ; break ;
40+ case BINARY_SUB : return x - y ; break ;
41+ case BINARY_GE : return x >= y ; break ;
42+ case BINARY_GT : return x > y ; break ;
43+ case BINARY_LE : return x <= y ; break ;
44+ case BINARY_LT : return x < y ; break ;
45+ case BINARY_EQ : return x == y ; break ;
46+ case BINARY_NE : return x != y ; break ;
47+ case RELU : // binary && relu = prelu
48+ return fwd_eltwise_common (RELU , x , y , beta , scale );
49+ break ;
50+ default : return 0.f ;
51+ }
52+ } else { // eltwise kind
53+ return fwd_eltwise_common (algorithm , x , alpha , beta , scale );
4754 }
4855}
4956
@@ -58,8 +65,8 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
5865 ret_val; \
5966 })
6067
61- #define FWD_XNARY_GENERIC_DT (algorithm , result , result_elem_dt , arg0_ptr , \
62- arg0_len , arg1_ptr , arg1_len , alpha , beta , scale ) \
68+ #define FWD_XNARY_GENERIC_DT (po_kind , algorithm , result , result_elem_dt , \
69+ arg0_ptr , arg0_len , arg1_ptr , arg1_len , alpha , beta , scale ) \
6370 { \
6471 auto ty = arg0_len + arg1_len; \
6572 const typeof(ty) out_len \
@@ -258,7 +265,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
258265 REPLICATE_DATA(bin_arg_ptr, bin_arg_size, x0_s, x1_size, x2_s, \
259266 x3_s, x4_s, x5_s); \
260267 } \
261- FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
268+ FWD_XNARY_GENERIC_DT(PO_BINARY, CONCAT3(PO_, idx, _ALG), accumulator, \
262269 acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
263270 (sizeof(accumulator) / sizeof(acc_elem_dt)), bin_arg_ptr, \
264271 bin_arg_size, 0.0f, 0.0f, 1.0f); \
@@ -273,7 +280,7 @@ float fwd_Xnary(unsigned algorithm, float x, float y, float alpha, float beta,
273280
274281#define APPLY_PO_ELTWISE (idx , accumulator , acc_elem_dt ) \
275282 { \
276- FWD_XNARY_GENERIC_DT(CONCAT3(PO_, idx, _ALG), accumulator, \
283+ FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
277284 acc_elem_dt, ((acc_elem_dt *)(&accumulator)), \
278285 (sizeof(accumulator) / sizeof(acc_elem_dt)), \
279286 ((acc_elem_dt *)(&accumulator)), \
0 commit comments