Skip to content

Commit f96798e

Browse files
committed
[BUG FIX] reformat based on comments
1 parent 2a9f2a6 commit f96798e

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

source/thead_rvv/fp32/prelu.c

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,17 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
7474
}
7575
} else {
7676
// broadcast alpha
77-
int new_alpha_size = csinn_tensor_size(input);
78-
float *alpha_data_b = shl_mem_alloc(new_alpha_size * sizeof(float));
77+
int input_size = csinn_tensor_size(input);
78+
float *alpha_data_b = shl_mem_alloc(input_size * sizeof(float));
7979
struct csinn_tensor *alpha_ = csinn_alloc_tensor(NULL);
8080
csinn_tensor_copy(alpha_, input);
8181
alpha_->data = alpha_data_b;
8282
shl_ref_broadcast_to_shape_f32(alpha, alpha_, alpha_->dim, alpha_->dim_count);
8383
alpha_data = (float *)alpha_->data;
8484

8585
// calculation
86-
int ele_size = input->dim[0] * input->dim[1] * input->dim[2] * input->dim[3];
87-
while (ele_size > 0) {
88-
int vl = vsetvl_e32m2(ele_size);
86+
while (input_size > 0) {
87+
int vl = vsetvl_e32m2(input_size);
8988
vfloat32m2_t _input = vle32_v_f32m2(input_data, vl);
9089
vfloat32m2_t _a = vle32_v_f32m2(alpha_data, vl);
9190
vbool16_t _mask = vmflt_vf_f32m2_b16(_input, 0.0f, vl);
@@ -94,20 +93,21 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
9493
input_data += vl;
9594
alpha_data += vl;
9695
output_data += vl;
97-
ele_size -= vl;
98-
}
99-
if (output->layout == CSINN_LAYOUT_NC1HWC0) {
100-
const int packn = csrr_vlenb() / sizeof(float);
101-
output->dim[1] *= packn;
102-
output->dim[4] = 0;
103-
output->dim_count = 4;
104-
output->layout = CSINN_LAYOUT_NCHW;
96+
input_size -= vl;
10597
}
10698

10799
// free memory and tensor
108100
shl_mem_free(alpha_data_b);
109101
csinn_free_tensor(alpha_);
110102
}
103+
104+
if (output->layout == CSINN_LAYOUT_NC1HWC0) {
105+
const int packn = csrr_vlenb() / sizeof(float);
106+
output->dim[1] *= packn;
107+
output->dim[4] = 0;
108+
output->dim_count = 4;
109+
output->layout = CSINN_LAYOUT_NCHW;
110+
}
111111
} else {
112112
shl_debug_error("prelu unsupported layout: %d\n", input->layout);
113113
}

0 commit comments

Comments
 (0)