
Commit 8c20f62

x64: brgemm matmul: enable blocked B for 3D problems
1 parent acb8e12 commit 8c20f62

3 files changed: 63 additions, 23 deletions

src/cpu/x64/matmul/brgemm_matmul.cpp

Lines changed: 15 additions & 8 deletions
@@ -305,7 +305,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
             ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
             : ptr_D;

-    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    const auto zp_comp_a
+            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
     const auto zp_comp_b
             = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
     const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
@@ -475,7 +476,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
                 // TODO: support reduction for zp/s8s8 compensations
                 // computed in copy routines
                 const auto zp_comp_a
-                        = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
+                        = brgmm_ctx.get_zp_a_compensation_ptr(
+                                ithr, b, nb);
                 const auto zp_comp_b
                         = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                 ithr, mb);
@@ -579,8 +581,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
     const int n = n_blk_idx * bgmmc.N_blk;
     const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
     ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
-    ctx.zp_a_compensation_ptr
-            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
+            ithr, b_idx, n_blk_idx);
     ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();

     int gb = 0;
@@ -709,8 +711,10 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
             // multitreaded execution mode
             const size_t reorder_zp_a_comp_offset
                     = weights_d.size() - weights_d.additional_buffer_size();
+            const size_t b_batch
+                    = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1;
             const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
-                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
+                    ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str
                     : 0;
             reorder_zp_a_comp_ptr_
                     = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
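
Note on the sizing change above: the s8s8 compensation scratch now reserves one s8s8_comp_b_str slice per distinct B batch. A minimal sketch of the indexing idea, assuming get_bb_idx simply collapses broadcast batches (the real helper consults bcast_B_desc, which can describe finer-grained broadcasts):

// Hypothetical simplification of get_bb_idx: map a global batch index to
// the index of the B tensor's batch, collapsing to 0 when B is broadcast.
int get_bb_idx_sketch(int b_idx, bool B_broadcast_across_batch) {
    return B_broadcast_across_batch ? 0 : b_idx;
}

// Under this model, b_batch = get_bb_idx(bgmmc.batch - 1, ...) + 1 evaluates
// to 1 for a broadcast B and to bgmmc.batch otherwise, which is exactly the
// number of compensation slices the buffer needs.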
@@ -965,7 +969,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                 ? n_blk_idx % bgmmc_.N_chunk_size
                 : n_blk_idx;
         return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
-                + b * bgmmc_.s8s8_comp_b_str
+                + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str
                 + n_blk_local * bgmmc_.s8s8_comp_n_str;
     }

@@ -987,7 +991,8 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {

     const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }

-    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
+    int32_t *get_zp_a_compensation_ptr(
+            int ithr, int b_idx, int n_blk_idx) const {
         if (!bgmmc_.has_zero_point_a) return nullptr;

         const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
@@ -1000,7 +1005,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
         // locally just before usage. Using the single global scaling before
         // parallel section might produce significant overhead for small
         // problems running in multitreaded execution mode
-        const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
+        const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc)
+                        * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk)
+                + n_blk_idx * bgmmc_.wei_n_blk;
         PRAGMA_OMP_SIMD()
         for (int b = 0; b < bgmmc_.wei_n_blk; b++)
             zp_comp[b] = -zero_point_a_negative_val_
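
Note on the offset above: each B batch owns a padded-N extent of zero-point compensation values. A small worked example under the usual oneDNN rnd_up semantics (values illustrative):

// rnd_up as used throughout oneDNN utilities: round x up to a multiple of m.
inline int rnd_up(int x, int m) { return ((x + m - 1) / m) * m; }

// Example: N = 100, wei_n_blk = 64. Each B batch owns
// rnd_up(100, 64) = 128 int32 entries, so the slice for B batch bb starts
// at bb * 128, and block n_blk_idx adds a further n_blk_idx * 64.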

src/cpu/x64/matmul/brgemm_matmul_utils.cpp

Lines changed: 32 additions & 15 deletions
@@ -42,15 +42,27 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
     switch (matrix_b_tag) {
+        case aCB16b64c:
+        case aCB16b64c2b:
+        case aCB16b64c4b:
         case BA16a64b4a:
         case BA16a64b2a:
         case BA16a64b: return 64;
+        case aCB16b48c:
+        case aCB16b48c2b:
+        case aCB16b48c4b:
         case BA16a48b:
         case BA16a48b2a:
         case BA16a48b4a: return 48;
+        case aCB16b32c:
+        case aCB16b32c2b:
+        case aCB16b32c4b:
         case BA16a32b:
         case BA16a32b2a:
         case BA16a32b4a: return 32;
+        case aCB16b16c:
+        case aCB16b16c2b:
+        case aCB16b16c4b:
         case BA16a16b:
         case BA16a16b2a:
         case BA16a16b4a: return 16;
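
The in-source note above hints at deriving the block size from the weights descriptor instead of extending this switch. A hedged sketch of that alternative using oneDNN's C blocking descriptor (not part of this commit):

#include "oneapi/dnnl/dnnl.h"

// Sketch only: read B's innermost N block from the blocking descriptor.
// For tags like BA16a64b4a or aCB16b64c4b, N is dimension ndims - 1.
int n_block_from_desc(const dnnl_memory_desc_t &md) {
    const auto &blk = md.format_desc.blocking;
    for (int i = 0; i < blk.inner_nblks; ++i)
        if (blk.inner_idxs[i] == md.ndims - 1) return (int)blk.inner_blks[i];
    return 0; // B carries no N blocking
}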
@@ -242,14 +254,17 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
 status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {

     memory_desc_t want_B_md = B_md;
+    // Set bits for all dimensions except k dimension
+    const int compensation_mask
+            = ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
     if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
         want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
-        want_B_md.extra.compensation_mask = (1 << 1);
+        want_B_md.extra.compensation_mask = compensation_mask;
     }
     if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
         want_B_md.extra.flags
                 |= memory_extra_flags::compensation_conv_asymmetric_src;
-        want_B_md.extra.asymm_compensation_mask = (1 << 1);
+        want_B_md.extra.asymm_compensation_mask = compensation_mask;
     }

     if (B_any_layout) {
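
A quick standalone check of the mask arithmetic above (not part of the commit): matmul reduces over K, which is dimension ndims - 2 of B, so the formula sets a bit for every dimension except that one:

#include <cassert>

// Bit d set means the compensation values vary along dimension d of B.
int compensation_mask(int ndims) {
    return (1 << ndims) - 1 - (1 << (ndims - 2));
}

int main() {
    assert(compensation_mask(2) == 0b10);  // 2D: N only -- the old (1 << 1)
    assert(compensation_mask(3) == 0b101); // 3D: batch and N, K skipped
    return 0;
}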
@@ -262,27 +277,29 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
 format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
         int n_blk) const {
-    if (bgmmc.ndims > 2) return format_tag::undef;
+
+    if (bgmmc.ndims > 3) return format_tag::undef;
     if (this->is_int8()) switch (n_blk) {
-            case 64: return BA16a64b4a;
-            case 48: return BA16a48b4a;
-            case 32: return BA16a32b4a;
-            case 16: return BA16a16b4a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
             default: return format_tag::undef;
         }
+
     if (this->is_bf16()) switch (n_blk) {
-            case 64: return BA16a64b2a;
-            case 48: return BA16a48b2a;
-            case 32: return BA16a32b2a;
-            case 16: return BA16a16b2a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
             default: return format_tag::undef;
         }
     // Note: bf32 assumes f32 blocking
     if (this->is_f32() || this->is_bf32()) switch (n_blk) {
-            case 64: return BA16a64b;
-            case 48: return BA16a48b;
-            case 32: return BA16a32b;
-            case 16: return BA16a16b;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
             default: return format_tag::undef;
         }
     return format_tag::undef;
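
As a reading aid for the tags above (naming convention only, nothing here is added by the commit beyond the tags themselves): each 3D tag is the 2D tag with a plain batch dimension prepended, which shifts every dimension letter by one:

// 2D B = {K, N} = {a, b}:
//   BA16a64b4a  -> N-block-major outer order, inner blocks 16(K) x 64(N),
//                  innermost 4(K) for int8 VNNI packing.
// 3D B = {batch, K, N} = {a, b, c}:
//   aCB16b64c4b -> the same blocking with batch 'a' outermost and unblocked.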

tests/benchdnn/inputs/matmul/harness_matmul_data_tags

Lines changed: 16 additions & 0 deletions
@@ -13,6 +13,7 @@
 --attr-fpmath=,bf16
 --wtag=BA16a64b,BA16a48b,BA16a32b,BA16a16b
 --batch=shapes_2d
+--attr-fpmath=

 --cfg=bf16bf16bf16
 --wtag=BA16a64b2a,BA16a48b2a,BA16a32b2a,BA16a16b2a
@@ -21,3 +22,18 @@
 --cfg=u8s8f32
 --wtag=BA16a64b4a,BA16a48b4a,BA16a32b4a,BA16a16b4a
 --batch=shapes_2d
+
+--stag=abc --dtag=abc
+--cfg=f32
+--attr-fpmath=,bf16
+--wtag=aCB16b16c,aCB16b32c,aCB16b48c,aCB16b64c
+--batch=shapes_3d
+--attr-fpmath=
+
+--cfg=bf16bf16bf16
+--wtag=aCB16b16c2b,aCB16b32c2b,aCB16b48c2b,aCB16b64c2b
+--batch=shapes_3d
+
+--cfg=u8s8f32
+--wtag=aCB16b16c4b,aCB16b32c4b,aCB16b48c4b,aCB16b64c4b
+--batch=shapes_3d
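
Outside the harness, a single new 3D case could be exercised directly; a hypothetical benchdnn invocation (the shape is illustrative, not taken from shapes_3d):

./benchdnn --matmul --cfg=u8s8f32 --stag=abc --wtag=aCB16b64c4b --dtag=abc \
    2x384x128:2x128x512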
