Commit 0e36a80
1 parent dcc8155

[BUG FIX] add broadcast mechanism before calculating PReLU when the input layout is NCHW
[BUG FIX] avoid broadcast in certain cases
[BUG FIX] reformat based on comments
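
For context, PReLU computes out = x for x >= 0 and out = alpha * x for x < 0, where alpha is a learned slope, typically one value per channel. A minimal scalar reference for the NCHW layout shows the computation the vectorized kernel below implements (an illustrative sketch only; the function name and signature are hypothetical, not part of SHL):

#include <stddef.h>

/* Scalar PReLU over an NCHW tensor with one alpha value per channel.
 * Hypothetical illustration only; not an SHL API. */
static void prelu_nchw_ref(const float *in, const float *alpha, float *out,
                           int n, int c, int h, int w)
{
    size_t inner = (size_t)h * (size_t)w;      /* elements per channel plane */
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < c; ++j) {
            float a = alpha[j];                /* per-channel negative slope */
            for (size_t k = 0; k < inner; ++k) {
                float x = *in++;
                *out++ = (x < 0.0f) ? a * x : x;
            }
        }
    }
}

The RVV kernel in this commit follows the same loop nest, replacing the innermost loop with masked vector operations.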

File tree

1 file changed (+45, -13)


source/thead_rvv/fp32/prelu.c

Lines changed: 45 additions & 13 deletions
@@ -17,6 +17,7 @@
  */
 
 #include "rvv/rvv.h"
+#include "reference/ref.h"
 
 int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
                        struct csinn_tensor *output, struct csinn_prelu_params *params)
@@ -53,22 +54,53 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
             output->layout = CSINN_LAYOUT_NC1HWC0;
         }
     } else if (input->layout == CSINN_LAYOUT_NCHW) {
-        for (int n = 0; n < input->dim[0]; ++n) {
-            for (int c = 0; c < input->dim[1]; ++c) {
-                float a = alpha_data[c];
-                int inner_size = input->dim[2] * input->dim[3];
-                while (inner_size > 0) {
-                    int vl = vsetvl_e32m2(inner_size);
-                    vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
-                    vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
-                    vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
-                    vse32_v_f32m2(output_data, _res, vl);
-                    input_data += vl;
-                    output_data += vl;
-                    inner_size -= vl;
+        if (alpha->dim[1] == csinn_tensor_size(alpha)) {
+            // simplify the calculation by avoiding broadcast
+            for (int n = 0; n < input->dim[0]; ++n) {
+                for (int c = 0; c < input->dim[1]; ++c) {
+                    float a = alpha_data[c];
+                    int inner_size = input->dim[2] * input->dim[3];
+                    while (inner_size > 0) {
+                        int vl = vsetvl_e32m2(inner_size);
+                        vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                        vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                        vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
+                        vse32_v_f32m2(output_data, _res, vl);
+                        input_data += vl;
+                        output_data += vl;
+                        inner_size -= vl;
+                    }
                 }
             }
+        } else {
+            // broadcast alpha
+            int input_size = csinn_tensor_size(input);
+            float *alpha_data_b = shl_mem_alloc(input_size * sizeof(float));
+            struct csinn_tensor *alpha_ = csinn_alloc_tensor(NULL);
+            csinn_tensor_copy(alpha_, input);
+            alpha_->data = alpha_data_b;
+            shl_ref_broadcast_to_shape_f32(alpha, alpha_, alpha_->dim, alpha_->dim_count);
+            alpha_data = (float *)alpha_->data;
+
+            // calculation
+            while (input_size > 0) {
+                int vl = vsetvl_e32m2(input_size);
+                vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                vfloat32m2_t _a = vle32_v_f32m2(alpha_data, vl);
+                vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                vfloat32m2_t _res = vfmul_vv_f32m2_m(_mask, _input, _input, _a, vl);
+                vse32_v_f32m2(output_data, _res, vl);
+                input_data += vl;
+                alpha_data += vl;
+                output_data += vl;
+                input_size -= vl;
+            }
+
+            // free memory and tensor
+            shl_mem_free(alpha_data_b);
+            csinn_free_tensor(alpha_);
         }
+
         if (output->layout == CSINN_LAYOUT_NC1HWC0) {
             const int packn = csrr_vlenb() / sizeof(float);
             output->dim[1] *= packn;
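
The fast-path guard alpha->dim[1] == csinn_tensor_size(alpha) holds when alpha carries exactly one value per channel (for example, shape [1, C, 1, 1]), so the kernel can read alpha_data[c] directly and skip the broadcast. For any other alpha shape, the patch first expands alpha to the full input shape via shl_ref_broadcast_to_shape_f32 and then streams one slope per element, which is why the masked multiply switches from the vector-scalar form vfmul_vf_f32m2_m to the vector-vector form vfmul_vv_f32m2_m. A sketch of what the expansion produces for the common [1, C, 1, 1] case (standalone illustration with a hypothetical function name; the real helper handles general broadcasting rules):

/* Expand per-channel alpha of shape [1, C, 1, 1] into a full [N, C, H, W]
 * buffer so each input element is paired with its own slope.
 * Hypothetical sketch; SHL's shl_ref_broadcast_to_shape_f32 implements
 * the general broadcasting rules. */
static void broadcast_alpha_nchw(const float *alpha, float *alpha_full,
                                 int n, int c, int h, int w)
{
    int idx = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < c; ++j) {
            for (int k = 0; k < h * w; ++k) {
                alpha_full[idx++] = alpha[j]; /* repeat alpha[j] over H*W */
            }
        }
    }
}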
