 */

 #include "rvv/rvv.h"
+#include "reference/ref.h"

 int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
                        struct csinn_tensor *output, struct csinn_prelu_params *params)
@@ -53,22 +54,53 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
             output->layout = CSINN_LAYOUT_NC1HWC0;
         }
     } else if (input->layout == CSINN_LAYOUT_NCHW) {
-        for (int n = 0; n < input->dim[0]; ++n) {
-            for (int c = 0; c < input->dim[1]; ++c) {
-                float a = alpha_data[c];
-                int inner_size = input->dim[2] * input->dim[3];
-                while (inner_size > 0) {
-                    int vl = vsetvl_e32m2(inner_size);
-                    vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
-                    vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
-                    vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
-                    vse32_v_f32m2(output_data, _res, vl);
-                    input_data += vl;
-                    output_data += vl;
-                    inner_size -= vl;
+        if (alpha->dim[1] == csinn_tensor_size(alpha)) {
+            // alpha holds exactly one value per channel, so no broadcast is needed
+            for (int n = 0; n < input->dim[0]; ++n) {
+                for (int c = 0; c < input->dim[1]; ++c) {
+                    float a = alpha_data[c];
+                    int inner_size = input->dim[2] * input->dim[3];
+                    while (inner_size > 0) {
+                        int vl = vsetvl_e32m2(inner_size);
+                        vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                        vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                        vfloat32m2_t _res = vfmul_vf_f32m2_m(_mask, _input, _input, a, vl);
+                        vse32_v_f32m2(output_data, _res, vl);
+                        input_data += vl;
+                        output_data += vl;
+                        inner_size -= vl;
+                    }
                 }
             }
+        } else {
+            // broadcast alpha to the input's full shape, then process flat
+            int input_size = csinn_tensor_size(input);
+            float *alpha_data_b = shl_mem_alloc(input_size * sizeof(float));
+            struct csinn_tensor *alpha_ = csinn_alloc_tensor(NULL);
+            csinn_tensor_copy(alpha_, input);
+            alpha_->data = alpha_data_b;
+            shl_ref_broadcast_to_shape_f32(alpha, alpha_, alpha_->dim, alpha_->dim_count);
+            alpha_data = (float *)alpha_->data;
+
+            // vectorized PReLU over the flattened buffer
+            while (input_size > 0) {
+                int vl = vsetvl_e32m2(input_size);
+                vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
+                vfloat32m2_t _a = vle32_v_f32m2(alpha_data, vl);
+                vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
+                vfloat32m2_t _res = vfmul_vv_f32m2_m(_mask, _input, _input, _a, vl);
+                vse32_v_f32m2(output_data, _res, vl);
+                input_data += vl;
+                alpha_data += vl;
+                output_data += vl;
+                input_size -= vl;
+            }
+
+            // free memory and tensor
+            shl_mem_free(alpha_data_b);
+            csinn_free_tensor(alpha_);
         }
+
         if (output->layout == CSINN_LAYOUT_NC1HWC0) {
             const int packn = csrr_vlenb() / sizeof(float);
             output->dim[1] *= packn;
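
Note: for readers unfamiliar with the masked intrinsics, the following scalar sketch models what the two new paths compute. It is an illustration only, not part of the patch; the function names are hypothetical.

#include <stddef.h>

/* Scalar model of the two vectorized paths above.
 * PReLU: out = x when x >= 0, otherwise alpha * x. */

/* Fast path: alpha stores exactly one scalar per channel. */
static void prelu_nchw_per_channel(const float *in, const float *alpha,
                                   float *out, int n, int c, int hw)
{
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < c; ++j) {
            float a = alpha[j];  /* one alpha per channel */
            const float *x = in + ((size_t)i * c + j) * hw;
            float *y = out + ((size_t)i * c + j) * hw;
            for (int k = 0; k < hw; ++k) {
                y[k] = (x[k] < 0.0f) ? x[k] * a : x[k];
            }
        }
    }
}

/* General path: alpha has already been broadcast to the input's full
 * shape, so input, alpha, and output are walked as flat buffers. */
static void prelu_broadcast_flat(const float *in, const float *alpha_b,
                                 float *out, size_t size)
{
    for (size_t i = 0; i < size; ++i) {
        out[i] = (in[i] < 0.0f) ? in[i] * alpha_b[i] : in[i];
    }
}

The vector loops fold this select into a single masked multiply: vmflt_vf_f32m2_b16 raises the mask on negative lanes, vfmul_vf_f32m2_m scales only those lanes by alpha, and the non-negative lanes pass through unchanged from the masked-off operand (_input), so no separate merge instruction is needed.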