Skip to content

Commit 28ddb5b

Browse files
committed
x64: brgemm matmul: fix tile configuration
1 parent 05629a5 commit 28ddb5b

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

src/cpu/x64/matmul/brgemm_matmul.cpp

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,22 @@ using namespace nstl;
4040

4141
using namespace data_type;
4242

43+
namespace {
44+
void maybe_tile_configure(bool is_amx,
45+
const char brg_kernel_palettes[][AMX_PALETTE_SIZE], int brg_ker_idx,
46+
int &prev_ker_idx) {
47+
if (!is_amx) return;
48+
if (brg_ker_idx == prev_ker_idx) return;
49+
// TODO: more accurately estimate the costs of memcmp and tile configuration
50+
if (prev_ker_idx == -1
51+
|| std::memcmp(&brg_kernel_palettes[brg_ker_idx][0],
52+
&brg_kernel_palettes[prev_ker_idx][0], AMX_PALETTE_SIZE)
53+
!= 0)
54+
amx_tile_configure(&brg_kernel_palettes[brg_ker_idx][0]);
55+
prev_ker_idx = brg_ker_idx;
56+
}
57+
} // namespace
58+
4359
template <cpu_isa_t isa>
4460
status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
4561
const auto src_dt = src_md_.data_type;
@@ -119,8 +135,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
119135
brgemm_attr_t brgattr;
120136
brgattr.generate_skip_accumulation
121137
= bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1;
122-
constexpr bool is_amx = one_of(
123-
isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
138+
const bool is_amx = is_superset(isa, avx512_core_amx);
124139
if (is_amx) {
125140
if (!brgattr.generate_skip_accumulation) {
126141
// TODO: uker doesn't yet support generate_skip_accumulation
@@ -217,10 +232,9 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
217232
balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(),
218233
ithr_k, kc_start, kc_end);
219234

220-
if (is_amx) {
221-
const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
222-
amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]);
223-
}
235+
int prev_ker_idx = -1;
236+
maybe_tile_configure(is_amx, brg_kernel_palettes_,
237+
brgmm_ctx.get_base_brgemm_kernel_idx(), prev_ker_idx);
224238

225239
int b {0}, mc {0}, nc {0};
226240
nd_iterator_init(
@@ -239,8 +253,8 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
239253
for (int mb = m_start; mb < m_end; mb++) {
240254
if (use_buffer_a && nb == n_start)
241255
copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc);
242-
compute_kernel(
243-
brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
256+
compute_kernel(brgmm_ctx, ithr, b, mb, nb, kc,
257+
kc == kc_start, prev_ker_idx);
244258
}
245259
}
246260
++start;
@@ -258,12 +272,12 @@ status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
258272
template <cpu_isa_t isa>
259273
void brgemm_matmul_t<isa>::compute_kernel(
260274
const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
261-
int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
275+
int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init,
276+
int &prev_ker_idx) const {
262277
constexpr bool is_amx
263278
= one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
264279
const auto &bgmmc = pd()->get_brgemm_matmul_conf();
265280
const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr);
266-
const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
267281

268282
const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr);
269283
const int m = m_blk_idx * bgmmc.M_blk;
@@ -302,10 +316,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
302316
if (gemm_batch > 0 && brg_ker_idx >= 0) {
303317
const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
304318
assert(brg_kernel != nullptr);
305-
306-
const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
307-
if (is_tile_reconf_required)
308-
amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
319+
maybe_tile_configure(
320+
is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);
309321

310322
brgmm_ctx.init_brgemm_batch_elements_values(
311323
ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);
@@ -339,9 +351,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
339351
brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
340352
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
341353
}
342-
343-
if (is_tile_reconf_required)
344-
amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
345354
}
346355
if (is_K_tail) {
347356
brgmm_ctx.init_brgemm_batch_elements_values(
@@ -350,11 +359,10 @@ void brgemm_matmul_t<isa>::compute_kernel(
350359
const bool use_init_ker = (do_init && gemm_batch == 0);
351360
const int brg_ker_idx = pd()->get_brg_kernel_idx(
352361
false, use_init_ker, is_M_tail, is_N_tail, true);
362+
maybe_tile_configure(
363+
is_amx, brg_kernel_palettes_, brg_ker_idx, prev_ker_idx);
353364
const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get();
354-
const bool is_tile_reconf_required
355-
= is_amx && bgmmc.K_tail != bgmmc.K_blk;
356-
if (is_tile_reconf_required)
357-
amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
365+
358366
if (post_ops_applicable) {
359367
void *scratch = is_amx
360368
? static_cast<void *>(wsp_tile)
@@ -384,8 +392,6 @@ void brgemm_matmul_t<isa>::compute_kernel(
384392
brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch,
385393
(void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
386394
}
387-
if (is_tile_reconf_required)
388-
amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
389395
}
390396
}
391397

@@ -394,6 +400,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
394400
const brg_matmul_exec_ctx_t &brgmm_ctx) const {
395401
if (!brgmm_ctx.parallel_reduction_is_used()) return;
396402

403+
const bool is_amx = is_superset(isa, avx512_core_amx);
404+
397405
const auto &bgmmc = pd()->get_brgemm_matmul_conf();
398406
const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();
399407

@@ -412,6 +420,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
412420
bmn_end);
413421
balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end);
414422

423+
int prev_ker_idx = -1;
424+
415425
int b {0}, mc {0}, nc {0};
416426

417427
assert(bgmmc.batch == 1);
@@ -450,6 +460,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
450460
= (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk);
451461
const int brg_ker_idx = pd()->get_brg_kernel_idx(
452462
false, false, is_M_tail, is_N_tail, false);
463+
maybe_tile_configure(is_amx, brg_kernel_palettes_,
464+
brg_ker_idx, prev_ker_idx);
453465
const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
454466
const int m = mb * bgmmc.M_blk;
455467
const int n = nb * bgmmc.N_blk;

src/cpu/x64/matmul/brgemm_matmul.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct brgemm_matmul_t : public primitive_t {
105105
status_t execute_body(const exec_ctx_t &ctx) const;
106106
void compute_kernel(const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr,
107107
int b_idx, int m_blk_idx, int n_blk_idx, int k_blk_idx,
108-
bool do_init) const;
108+
bool do_init, int &prev_ker_idx) const;
109109
void copy_a_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
110110
int ithr, int b_idx, int m_blk_idx, int k_blk_idx) const;
111111
void copy_b_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
@@ -116,7 +116,7 @@ struct brgemm_matmul_t : public primitive_t {
116116
char *result_ptr, const char *reduce_ptr, size_t size) const;
117117

118118
std::unique_ptr<brgemm_kernel_t> brg_kernels_[max_num_brg_kernels_matmul];
119-
char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
119+
alignas(64) char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
120120
std::unique_ptr<jit_brgemm_matmul_copy_b_t> copy_B_kernel_;
121121
std::unique_ptr<jit_brgemm_matmul_copy_a_t> copy_A_kernel_;
122122
std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_f32_;

0 commit comments

Comments
 (0)