Skip to content

Commit 0184044

Browse files
committed
x64: brgemm bwd_d strided: remove static last_oc_block_size
1 parent e4737d9 commit 0184044

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

src/cpu/x64/jit_brgemm_conv_bwd_utils.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {
455455
static constexpr int bench_iterations = 1;
456456

457457
int sp, sp_block, nb_sp;
458-
static int last_oc_block_size;
459458

460459
void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
461460
void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }
@@ -531,7 +530,6 @@ struct brg_blocking_t : public jit_brgemm_conv_conf_t {
531530

532531
unsigned brg_blocking_t::L1;
533532
unsigned brg_blocking_t::L2;
534-
int brg_blocking_t::last_oc_block_size;
535533

536534
float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
537535
bool is_broadcast, bool is_shared) const {
@@ -556,7 +554,7 @@ float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
556554
}
557555

558556
void brg_blocking_t::select_oc_block() {
559-
const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1);
557+
const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1);
560558
oc_block = (exec_type == exec_trans ? rnd_up(oc, padded_oc) : oc);
561559
nb_oc = utils::div_up(oc, oc_block);
562560
}
@@ -570,7 +568,7 @@ status_t brg_blocking_t::estimate_brgemm_ur() {
570568

571569
// Configure matrix sizes
572570
// for amx if oc_block != oc then we use exec_trans so K is oc_block
573-
const auto padded_oc = last_oc_block_size * (is_oc_padded ? acc_simd_w : 1);
571+
const auto padded_oc = vnni_block * (is_oc_padded ? acc_simd_w : 1);
574572

575573
ocp = rnd_up(oc, padded_oc);
576574

@@ -641,7 +639,7 @@ status_t brg_blocking_t::get_brgemm_ur(
641639
brg_strides.stride_a = ngroups * oc_without_padding
642640
* (dilate_w + 1) * src_dsz;
643641
//weights are padded by ic_block and last_oc_block
644-
brg_strides.stride_b = rnd_up(oc, last_oc_block_size)
642+
brg_strides.stride_b = rnd_up(oc, vnni_block)
645643
* rnd_up(ic, ic_block) * wei_dsz;
646644
const auto strides_ptr = (brg_type == brgemm_strd)
647645
? &brg_strides
@@ -1501,8 +1499,7 @@ status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
15011499

15021500
VDISPATCH_CONV_IC(!jcp.is_bf32, VERBOSE_UNSUPPORTED_DT);
15031501

1504-
brg_blocking_t::last_oc_block_size
1505-
= (jcp.wei_dt == f16 && isa == avx512_core_fp16)
1502+
jcp.vnni_block = (jcp.wei_dt == f16 && isa == avx512_core_fp16)
15061503
? 1
15071504
: data_type_vnni_granularity(jcp.wei_dt);
15081505

@@ -1921,8 +1918,7 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
19211918

19221919
jcp.copy_block_only = true;
19231920

1924-
const auto oc_padded_block
1925-
= jcp.acc_simd_w * brg_blocking_t::last_oc_block_size;
1921+
const auto oc_padded_block = jcp.acc_simd_w * jcp.vnni_block;
19261922
jcp.is_oc_padded = one_of(jcp.wei_dt, bf16, f16, s8)
19271923
&& jcp.oc > oc_padded_block && is_amx(isa);
19281924

0 commit comments

Comments
 (0)