@@ -130,18 +130,7 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(
130130
131131 const auto adj_M = nstl::max (jcp_.M , jcp_.M_tail );
132132
133- batchsizes.resize (jcp_.max_batch + 1 );
134- for (int i = 0 ; i <= jcp_.max_batch ; i++)
135- batchsizes[i] = -1 ;
136-
137- first_bs = 0 ;
138- bs_c = 0 ;
139-
140- batchsizes[jcp_.max_batch ] = bs_c;
141- first_bs = jcp_.max_batch ;
142- bs_c++;
143-
144- brgs_sz_ = bs_c * adj_M * 2 * 2 * 2 ;
133+ brgs_sz_ = adj_M * 2 * 2 * 2 ;
145134 brgs_ = std::make_shared<brgemm_containers::brgemm_desc_container_t >();
146135 brgs_->resize (brgs_sz_);
147136
@@ -159,86 +148,79 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(
159148 if (one_of (jcp_.exec_type , exec_trans, exec_vpad) && vM != jcp_.M
160149 && vM != jcp_.M_tail )
161150 continue ;
162- for (int bs = 0 ; bs <= jcp_.max_batch ; bs++) {
163- if (batchsizes[bs] == -1 ) continue ;
164- for_ (int i_init = 0 ; i_init < 2 ; i_init++)
165- for_ (int i_N = 0 ; i_N < 2 ; i_N++)
166- for (int i_K = 0 ; i_K < 2 ; i_K++) {
167- auto vbeta = (i_init) ? 0 : beta;
168- auto vN = (i_N) ? jcp_.N_tail : jcp_.N ;
169- auto vK = (i_K) ? jcp_.K_tail : jcp_.K ;
170- auto vbrgM = jcp_.use_M_mask
171- ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail )
172- : vM;
173- auto brg_idx = get_brg_idx (bs, i, i_init, i_N, i_K);
174- // if brgemm_t already created then skip this iteration
175- if ((*brgs_)[brg_idx] != nullptr ) continue ;
176- brgemm_t brg;
177- if (vN == 0 || vK == 0 ) continue ;
178- brgemm_strides_t brg_strides;
179- brg_strides.stride_a = jcp_.brg_stride_a ;
180- brg_strides.stride_b = jcp_.brg_stride_b ;
181- brg.req_cal_comp_pads = jcp_.req_brg_comp_pad ;
182- brg.req_comp_pads_with_bcast
183- = jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
184- const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
185- ? &brg_strides
186- : nullptr ;
187- CHECK (brgemm_desc_init (&brg, isa, jcp_.brg_type , diff_dst_type,
188- wei_type, false , false , brgemm_row_major, alpha, vbeta,
189- jcp_.LDA , jcp_.LDB , jcp_.LDC , vbrgM, vN, vK,
190- strides_ptr));
191-
192- brgemm_attr_t brgattr;
193- brgattr.use_uker = jcp_.use_uker ;
194- brgattr.use_interleave_stores = jcp_.use_interleave_stores ;
195- brgattr.hint_prefetching = jcp_.hint_prefetching ;
196- brgattr.max_bs = bs;
197- brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
198- ? brgemm_bd_loop_innermost
199- : brgemm_ld_loop_innermost;
200- if (jcp_.amx_tile_load_xx ) {
201- // assuming 2x2 decomposition in amx brgemm kernel
202- // and overlap of input by kw
203- const auto bd_blocking = 2 * jcp_.amx_h ;
204- const auto ld_blocking = 2 * 16 ;
205- brgattr.hint_expected_A_size = bd_blocking * jcp_.K
206- * jcp_.kd_block * jcp_.kh_block ;
207- brgattr.hint_expected_B_size = ld_blocking * jcp_.K
208- * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block ;
209- brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
210- } else {
211- brgattr.hint_expected_A_size = 0 ;
212- brgattr.hint_expected_B_size = 0 ;
213- brgattr.hint_expected_C_size = 0 ;
214- }
151+ for_ (int i_init = 0 ; i_init < 2 ; i_init++)
152+ for_ (int i_N = 0 ; i_N < 2 ; i_N++)
153+ for (int i_K = 0 ; i_K < 2 ; i_K++) {
154+ auto vbeta = (i_init) ? 0 : beta;
155+ auto vN = (i_N) ? jcp_.N_tail : jcp_.N ;
156+ auto vK = (i_K) ? jcp_.K_tail : jcp_.K ;
157+ auto vbrgM = jcp_.use_M_mask
158+ ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail )
159+ : vM;
160+ auto brg_idx = get_brg_idx (jcp_.max_batch , i, i_init, i_N, i_K);
161+ // if brgemm_t already created then skip this iteration
162+ if ((*brgs_)[brg_idx] != nullptr ) continue ;
163+ brgemm_t brg;
164+ if (vN == 0 || vK == 0 ) continue ;
165+ brgemm_strides_t brg_strides;
166+ brg_strides.stride_a = jcp_.brg_stride_a ;
167+ brg_strides.stride_b = jcp_.brg_stride_b ;
168+ brg.req_cal_comp_pads = jcp_.req_brg_comp_pad ;
169+ brg.req_comp_pads_with_bcast
170+ = jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
171+ const auto strides_ptr
172+ = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr ;
173+ CHECK (brgemm_desc_init (&brg, isa, jcp_.brg_type , diff_dst_type,
174+ wei_type, false , false , brgemm_row_major, alpha, vbeta,
175+ jcp_.LDA , jcp_.LDB , jcp_.LDC , vbrgM, vN, vK, strides_ptr));
176+
177+ brgemm_attr_t brgattr;
178+ brgattr.use_uker = jcp_.use_uker ;
179+ brgattr.use_interleave_stores = jcp_.use_interleave_stores ;
180+ brgattr.hint_prefetching = jcp_.hint_prefetching ;
181+ brgattr.max_bs = jcp_.max_batch ;
182+ brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
183+ ? brgemm_bd_loop_innermost
184+ : brgemm_ld_loop_innermost;
185+ if (jcp_.amx_tile_load_xx ) {
186+ // assuming 2x2 decomposition in amx brgemm kernel
187+ // and overlap of input by kw
188+ const auto bd_blocking = 2 * jcp_.amx_h ;
189+ const auto ld_blocking = 2 * 16 ;
190+ brgattr.hint_expected_A_size
191+ = bd_blocking * jcp_.K * jcp_.kd_block * jcp_.kh_block ;
192+ brgattr.hint_expected_B_size = ld_blocking * jcp_.K
193+ * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block ;
194+ brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
195+ } else {
196+ brgattr.hint_expected_A_size = 0 ;
197+ brgattr.hint_expected_B_size = 0 ;
198+ brgattr.hint_expected_C_size = 0 ;
199+ }
215200
216- brgattr.wary_tail_read = false ;
217- // use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
218- brgattr.bd_mask = nullptr ;
219- brgattr.bd_mask_level = jcp_.use_M_mask ;
220-
221- if (is_amx) {
222- brgattr.max_top_vpad = 0 ;
223- brgattr.max_bottom_vpad = 0 ;
224- } else {
225- brgattr.max_top_vpad = jcp_.max_vpad ;
226- brgattr.max_bottom_vpad = jcp_.max_vpad ;
227- }
228- brgattr.generate_skip_accumulation = true ;
229- CHECK (brgemm_desc_set_attr (&brg, brgattr));
230-
231- auto LDD = jcp_.stride_w * jcp_.ic_without_padding ;
232- brg.with_sum = with_sum;
233- brg.with_weights_scale_adjust
234- = jcp_.scale_adjust_factor != 1 .0f ;
235- CHECK (brgemm_desc_set_postops (
236- &brg, attr (), &diff_src_md_, LDD, jcp_.bia_dt ));
237- jcp_.amx_buf_size_per_thread
238- = nstl::max (brg.get_wsp_buffer_size (),
239- jcp_.amx_buf_size_per_thread );
240- brgs_->insert (brg_idx, brg);
201+ brgattr.wary_tail_read = false ;
202+ // use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
203+ brgattr.bd_mask = nullptr ;
204+ brgattr.bd_mask_level = jcp_.use_M_mask ;
205+
206+ if (is_amx) {
207+ brgattr.max_top_vpad = 0 ;
208+ brgattr.max_bottom_vpad = 0 ;
209+ } else {
210+ brgattr.max_top_vpad = jcp_.max_vpad ;
211+ brgattr.max_bottom_vpad = jcp_.max_vpad ;
241212 }
213+ brgattr.generate_skip_accumulation = true ;
214+ CHECK (brgemm_desc_set_attr (&brg, brgattr));
215+
216+ auto LDD = jcp_.stride_w * jcp_.ic_without_padding ;
217+ brg.with_sum = with_sum;
218+ brg.with_weights_scale_adjust = jcp_.scale_adjust_factor != 1 .0f ;
219+ CHECK (brgemm_desc_set_postops (
220+ &brg, attr (), &diff_src_md_, LDD, jcp_.bia_dt ));
221+ jcp_.amx_buf_size_per_thread = nstl::max (
222+ brg.get_wsp_buffer_size (), jcp_.amx_buf_size_per_thread );
223+ brgs_->insert (brg_idx, brg);
242224 }
243225 }
244226
@@ -410,17 +392,13 @@ void brgemm_convolution_bwd_strided_t<isa, is_deconv>::create_kernels() {
410392 : 0 ;
411393 int i_init_end = 2 ;
412394
413- for (int bs = 0 ; bs <= jcp.max_batch ; bs++) {
414- if (_pd->batchsizes [bs] == -1 ) continue ;
415-
416- for_ (int i_N = N_begin; i_N < N_end; i_N++)
417- for_ (int i_M = M_begin; i_M < M_end; i_M++)
418- for_ (int i_init = i_init_begin; i_init < i_init_end; i_init++)
419- for (int i_K = K_begin; i_K < K_end; i_K++) {
420- auto M = (i_M) ? jcp.M_tail : jcp.M ;
421- if (M <= 0 ) continue ;
422- add_brg_kernel (bs, M, i_N, i_K, i_init);
423- }
395+ for_ (int i_N = N_begin; i_N < N_end; i_N++)
396+ for_ (int i_M = M_begin; i_M < M_end; i_M++)
397+ for_ (int i_init = i_init_begin; i_init < i_init_end; i_init++)
398+ for (int i_K = K_begin; i_K < K_end; i_K++) {
399+ auto M = (i_M) ? jcp.M_tail : jcp.M ;
400+ if (M <= 0 ) continue ;
401+ add_brg_kernel (jcp.max_batch , M, i_N, i_K, i_init);
424402 }
425403
426404 if (jcp.exec_type == exec_base) {
@@ -444,14 +422,11 @@ void brgemm_convolution_bwd_strided_t<isa, is_deconv>::create_kernels() {
444422 for (int kw = kw_s; kw < kw_f; kw++) {
445423 get_iw_range (iw_str, iw, kw, iw_s, M_without_overflow);
446424 if (M_without_overflow <= 0 ) continue ;
447- for (int bs = 0 ; bs <= jcp.max_batch ; bs++) {
448- if (_pd->batchsizes [bs] == -1 ) continue ;
449- for_ (int i_init = 0 ; i_init < 2 ; i_init++)
450- for_ (int i_N = 0 ; i_N < 2 ; i_N++)
451- for (int i_K = 0 ; i_K < 2 ; i_K++) {
452- add_brg_kernel (
453- bs, M_without_overflow, i_N, i_K, i_init);
454- }
425+ for_ (int i_init = 0 ; i_init < 2 ; i_init++)
426+ for_ (int i_N = 0 ; i_N < 2 ; i_N++)
427+ for (int i_K = 0 ; i_K < 2 ; i_K++) {
428+ add_brg_kernel (jcp.max_batch , M_without_overflow, i_N, i_K,
429+ i_init);
455430 }
456431
457432 bool is_iw_tail = (jcp.iw - iw < jcp.iw_block );
0 commit comments