@@ -40,6 +40,22 @@ using namespace nstl;
4040
4141using namespace data_type ;
4242
43+ namespace {
44+ void maybe_tile_configure (bool is_amx,
45+ const char brg_kernel_palettes[][AMX_PALETTE_SIZE], int brg_ker_idx,
46+ int &prev_ker_idx) {
47+ if (!is_amx) return ;
48+ if (brg_ker_idx == prev_ker_idx) return ;
49+ // TODO: more accurately estimate the costs of memcmp and tile configuration
50+ if (prev_ker_idx == -1
51+ || std::memcmp (&brg_kernel_palettes[brg_ker_idx][0 ],
52+ &brg_kernel_palettes[prev_ker_idx][0 ], AMX_PALETTE_SIZE)
53+ != 0 )
54+ amx_tile_configure (&brg_kernel_palettes[brg_ker_idx][0 ]);
55+ prev_ker_idx = brg_ker_idx;
56+ }
57+ } // namespace
58+
4359template <cpu_isa_t isa>
4460status_t brgemm_matmul_t <isa>::pd_t ::init(engine_t *engine) {
4561 const auto src_dt = src_md_.data_type ;
@@ -119,8 +135,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
119135 brgemm_attr_t brgattr;
120136 brgattr.generate_skip_accumulation
121137 = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1 ;
122- constexpr bool is_amx = one_of (
123- isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
138+ const bool is_amx = is_superset (isa, avx512_core_amx);
124139 if (is_amx) {
125140 if (!brgattr.generate_skip_accumulation ) {
126141 // TODO: uker doesn't yet support generate_skip_accumulation
@@ -217,10 +232,9 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
217232 balance211 ((int )bgmmc.K_chunks , brgmm_ctx.get_num_threads_for_k (),
218233 ithr_k, kc_start, kc_end);
219234
220- if (is_amx) {
221- const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx ();
222- amx_tile_configure (&brg_kernel_palettes_[base_ker_idx][0 ]);
223- }
235+ int prev_ker_idx = -1 ;
236+ maybe_tile_configure (is_amx, brg_kernel_palettes_,
237+ brgmm_ctx.get_base_brgemm_kernel_idx (), prev_ker_idx);
224238
225239 int b {0 }, mc {0 }, nc {0 };
226240 nd_iterator_init (
@@ -239,8 +253,8 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
239253 for (int mb = m_start; mb < m_end; mb++) {
240254 if (use_buffer_a && nb == n_start)
241255 copy_a_chunk_in_buffer (brgmm_ctx, ithr, b, mb, kc);
242- compute_kernel (
243- brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
256+ compute_kernel (brgmm_ctx, ithr, b, mb, nb, kc,
257+ kc == kc_start, prev_ker_idx );
244258 }
245259 }
246260 ++start;
@@ -258,12 +272,12 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
258272template <cpu_isa_t isa>
259273void brgemm_matmul_t <isa>::compute_kernel(
260274 const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
261- int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
275+ int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init,
276+ int &prev_ker_idx) const {
262277 constexpr bool is_amx
263278 = one_of (isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
264279 const auto &bgmmc = pd ()->get_brgemm_matmul_conf ();
265280 const auto addr_batch = brgmm_ctx.get_batch_elem_ptr (ithr);
266- const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx ();
267281
268282 const auto wsp_tile = brgmm_ctx.get_tile_workspace (ithr);
269283 const int m = m_blk_idx * bgmmc.M_blk ;
@@ -302,10 +316,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
302316 if (gemm_batch > 0 && brg_ker_idx >= 0 ) {
303317 const auto brg_kernel = brg_kernels_[brg_ker_idx].get ();
304318 assert (brg_kernel != nullptr );
305-
306- const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
307- if (is_tile_reconf_required)
308- amx_tile_configure (&brg_kernel_palettes_[brg_ker_idx][0 ]);
319+ maybe_tile_configure (
320+ is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);
309321
310322 brgmm_ctx.init_brgemm_batch_elements_values (
311323 ithr, 0 , gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);
@@ -339,9 +351,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
339351 brgemm_kernel_execute (brg_kernel, gemm_batch, addr_batch,
340352 (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr );
341353 }
342-
343- if (is_tile_reconf_required)
344- amx_tile_configure (&brg_kernel_palettes_[base_brg_ker_idx][0 ]);
345354 }
346355 if (is_K_tail) {
347356 brgmm_ctx.init_brgemm_batch_elements_values (
@@ -350,11 +359,10 @@ void brgemm_matmul_t<isa>::compute_kernel(
350359 const bool use_init_ker = (do_init && gemm_batch == 0 );
351360 const int brg_ker_idx = pd ()->get_brg_kernel_idx (
352361 false , use_init_ker, is_M_tail, is_N_tail, true );
362+ maybe_tile_configure (
363+ is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);
353364 const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get ();
354- const bool is_tile_reconf_required
355- = is_amx && bgmmc.K_tail != bgmmc.K_blk ;
356- if (is_tile_reconf_required)
357- amx_tile_configure (&brg_kernel_palettes_[brg_ker_idx][0 ]);
365+
358366 if (post_ops_applicable) {
359367 void *scratch = is_amx
360368 ? static_cast <void *>(wsp_tile)
@@ -384,8 +392,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
384392 brgemm_kernel_execute (brg_kernel_k_tail, 1 , addr_batch,
385393 (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr );
386394 }
387- if (is_tile_reconf_required)
388- amx_tile_configure (&brg_kernel_palettes_[base_brg_ker_idx][0 ]);
389395 }
390396}
391397
@@ -394,6 +400,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
394400 const brg_matmul_exec_ctx_t &brgmm_ctx) const {
395401 if (!brgmm_ctx.parallel_reduction_is_used ()) return ;
396402
403+ const bool is_amx = is_superset (isa, avx512_core_amx);
404+
397405 const auto &bgmmc = pd ()->get_brgemm_matmul_conf ();
398406 const int num_threads = brgmm_ctx.get_num_threads_for_parallelization ();
399407
@@ -412,6 +420,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
412420 bmn_end);
413421 balance211 (bmn_end - bmn_start, nthr_k, ithr_k, start, end);
414422
423+ int prev_ker_idx = -1 ;
424+
415425 int b {0 }, mc {0 }, nc {0 };
416426
417427 assert (bgmmc.batch == 1 );
@@ -450,6 +460,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
450460 = (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk );
451461 const int brg_ker_idx = pd ()->get_brg_kernel_idx (
452462 false , false , is_M_tail, is_N_tail, false );
463+ maybe_tile_configure (is_amx, brg_kernel_palettes_,
464+ brg_ker_idx, prev_ker_idx);
453465 const auto brg_kernel = brg_kernels_[brg_ker_idx].get ();
454466 const int m = mb * bgmmc.M_blk ;
455467 const int n = nb * bgmmc.N_blk ;
0 commit comments