@@ -305,7 +305,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
             ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
             : ptr_D;
 
-    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    const auto zp_comp_a
+            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
     const auto zp_comp_b
             = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
     const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
@@ -475,7 +476,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
                 // TODO: support reduction for zp/s8s8 compensations
                 // computed in copy routines
                 const auto zp_comp_a
-                        = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
+                        = brgmm_ctx.get_zp_a_compensation_ptr(
+                                ithr, b, nb);
                 const auto zp_comp_b
                         = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                 ithr, mb);
@@ -579,8 +581,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
     const int n = n_blk_idx * bgmmc.N_blk;
     const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
     ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
-    ctx.zp_a_compensation_ptr
-            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
+            ithr, b_idx, n_blk_idx);
     ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
 
     int gb = 0;
@@ -709,8 +711,10 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
             // multitreaded execution mode
             const size_t reorder_zp_a_comp_offset
                     = weights_d.size() - weights_d.additional_buffer_size();
+            const size_t b_batch
+                    = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1;
             const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
-                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
+                    ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str
                     : 0;
             reorder_zp_a_comp_ptr_
                     = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
@@ -965,7 +969,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                 ? n_blk_idx % bgmmc_.N_chunk_size
                 : n_blk_idx;
         return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
-                + b * bgmmc_.s8s8_comp_b_str
+                + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str
                 + n_blk_local * bgmmc_.s8s8_comp_n_str;
     }
 
@@ -987,7 +991,8 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
 
     const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }
 
-    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
+    int32_t *get_zp_a_compensation_ptr(
+            int ithr, int b_idx, int n_blk_idx) const {
         if (!bgmmc_.has_zero_point_a) return nullptr;
 
         const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
@@ -1000,7 +1005,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
             // locally just before usage. Using the single global scaling before
             // parallel section might produce significant overhead for small
             // problems running in multitreaded execution mode
-            const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
+            const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc)
+                            * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk)
+                    + n_blk_idx * bgmmc_.wei_n_blk;
             PRAGMA_OMP_SIMD()
             for (int b = 0; b < bgmmc_.wei_n_blk; b++)
                 zp_comp[b] = -zero_point_a_negative_val_