@@ -74,18 +74,17 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
7474 }
7575 } else {
7676 // broadcast alpha
77- int new_alpha_size = csinn_tensor_size (input );
78- float * alpha_data_b = shl_mem_alloc (new_alpha_size * sizeof (float ));
77+ int input_size = csinn_tensor_size (input );
78+ float * alpha_data_b = shl_mem_alloc (input_size * sizeof (float ));
7979 struct csinn_tensor * alpha_ = csinn_alloc_tensor (NULL );
8080 csinn_tensor_copy (alpha_ , input );
8181 alpha_ -> data = alpha_data_b ;
8282 shl_ref_broadcast_to_shape_f32 (alpha , alpha_ , alpha_ -> dim , alpha_ -> dim_count );
8383 alpha_data = (float * )alpha_ -> data ;
8484
8585 // calculation
86- int ele_size = input -> dim [0 ] * input -> dim [1 ] * input -> dim [2 ] * input -> dim [3 ];
87- while (ele_size > 0 ) {
88- int vl = vsetvl_e32m2 (ele_size );
86+ while (input_size > 0 ) {
87+ int vl = vsetvl_e32m2 (input_size );
8988 vfloat32m2_t _input = vle32_v_f32m2 (input_data , vl );
9089 vfloat32m2_t _a = vle32_v_f32m2 (alpha_data , vl );
9190 vbool16_t _mask = vmflt_vf_f32m2_b16 (_input , 0.0f , vl );
@@ -94,20 +93,21 @@ int shl_rvv_prelu_fp32(struct csinn_tensor *input, struct csinn_tensor *alpha,
9493 input_data += vl ;
9594 alpha_data += vl ;
9695 output_data += vl ;
97- ele_size -= vl ;
98- }
99- if (output -> layout == CSINN_LAYOUT_NC1HWC0 ) {
100- const int packn = csrr_vlenb () / sizeof (float );
101- output -> dim [1 ] *= packn ;
102- output -> dim [4 ] = 0 ;
103- output -> dim_count = 4 ;
104- output -> layout = CSINN_LAYOUT_NCHW ;
96+ input_size -= vl ;
10597 }
10698
10799 // free memory and tensor
108100 shl_mem_free (alpha_data_b );
109101 csinn_free_tensor (alpha_ );
110102 }
103+
104+ if (output -> layout == CSINN_LAYOUT_NC1HWC0 ) {
105+ const int packn = csrr_vlenb () / sizeof (float );
106+ output -> dim [1 ] *= packn ;
107+ output -> dim [4 ] = 0 ;
108+ output -> dim_count = 4 ;
109+ output -> layout = CSINN_LAYOUT_NCHW ;
110+ }
111111 } else {
112112 shl_debug_error ("prelu unsupported layout: %d\n" , input -> layout );
113113 }
0 commit comments