@@ -455,7 +455,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {
455455 static constexpr int bench_iterations = 1 ;
456456
457457 int sp, sp_block, nb_sp;
458- static int last_oc_block_size;
459458
460459 void get_from_jcp (const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
461460 void save_to_jcp (jit_brgemm_conv_conf_t &jcp) const { jcp = *this ; }
@@ -531,7 +530,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {
531530
532531unsigned brg_blocking_t ::L1;
533532unsigned brg_blocking_t ::L2;
534- int brg_blocking_t ::last_oc_block_size;
535533
536534float brg_blocking_t::io_k (dim_t src, dim_t wei, dim_t dst, float n, float pk,
537535 bool is_broadcast, bool is_shared) const {
@@ -556,7 +554,7 @@ float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
556554}
557555
558556void brg_blocking_t::select_oc_block () {
559- const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1 );
557+ const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1 );
560558 oc_block = (exec_type == exec_trans ? rnd_up (oc, padded_oc) : oc);
561559 nb_oc = utils::div_up (oc, oc_block);
562560}
@@ -570,7 +568,7 @@ status_t brg_blocking_t::estimate_brgemm_ur() {
570568
571569 // Configure matrix sizes
572570 // for amx if oc_block != oc then we use exec_trans so K is oc_block
573- const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1 );
571+ const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1 );
574572
575573 ocp = rnd_up (oc, padded_oc);
576574
@@ -641,7 +639,7 @@ status_t brg_blocking_t::get_brgemm_ur(
641639 brg_strides.stride_a = ngroups * oc_without_padding
642640 * (dilate_w + 1 ) * src_dsz;
643641 // weights are padded: ic is rounded up to ic_block, oc to vnni_block
644- brg_strides.stride_b = rnd_up (oc, last_oc_block_size )
642+ brg_strides.stride_b = rnd_up (oc, vnni_block )
645643 * rnd_up (ic, ic_block) * wei_dsz;
646644 const auto strides_ptr = (brg_type == brgemm_strd)
647645 ? &brg_strides
@@ -1501,8 +1499,7 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
15011499
15021500 VDISPATCH_CONV_IC (!jcp.is_bf32 , VERBOSE_UNSUPPORTED_DT);
15031501
1504- brg_blocking_t ::last_oc_block_size
1505- = (jcp.wei_dt == f16 && isa == avx512_core_fp16)
1502+ jcp.vnni_block = (jcp.wei_dt == f16 && isa == avx512_core_fp16)
15061503 ? 1
15071504 : data_type_vnni_granularity (jcp.wei_dt );
15081505
@@ -1921,8 +1918,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
19211918
19221919 jcp.copy_block_only = true ;
19231920
1924- const auto oc_padded_block
1925- = jcp.acc_simd_w * brg_blocking_t ::last_oc_block_size;
1921+ const auto oc_padded_block = jcp.acc_simd_w * jcp.vnni_block ;
19261922 jcp.is_oc_padded = one_of (jcp.wei_dt , bf16 , f16 , s8)
19271923 && jcp.oc > oc_padded_block && is_amx (isa);
19281924
0 commit comments