@@ -22,8 +22,6 @@ using namespace dnnl::impl::cpu::x64;
2222namespace ov {
2323namespace intel_cpu {
2424
25- const bool transpose_b_enable = std::getenv(" TRANSPOSE_B" );
26-
2725jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter (jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr)
2826 : jit_emitter(h, isa) {
2927 in_out_type_ = emitter_in_out_map::gpr_to_gpr;
@@ -43,16 +41,17 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
4341 const auto transposed_shape = layout.empty () ? original_shape : snippets::utils::get_planar_vdims (original_shape, layout);
4442 const size_t N = transposed_shape.back ();
4543
46- dnnl_format_tag_t format = dnnl_abcd;
44+ m_format = dnnl_abcd;
4745 size_t wei_stride = 0 ;
4846 if (layout == VectorDims{0 , 2 , 1 , 3 }) {
49- format = dnnl_acbd;
47+ std::cout << " wei stride is needed!!\n " ;
48+ m_format = dnnl_acbd;
5049 wei_stride = jit_brgemm_emitter::get_in_leading_dim (original_shape, layout);
50+ } else if (layout == VectorDims{0 , 1 , 3 , 2 }) {
51+ std::cout << " transposed copy_b shape\n " ;
52+ m_format = dnnl_abdc;
5153 }
5254
53- if (transpose_b_enable)
54- format = dnnl_abdc;
55-
5655 std::cout << " [ INFO ] CopyBEmitter is processing...\n " ;
5756 std::cout << " \t shape = " << ov::PartialShape (original_shape) << std::endl;
5857 std::cout << " \t layout = " << ov::PartialShape (layout) << std::endl;
@@ -77,7 +76,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
7776 OV_CPU_JIT_EMITTER_ASSERT (!one_of (m_brg_weight_etype, element::bf16 , element::i8 ), " doesn't support precision " , m_brg_weight_etype);
7877 const auto repacking_buffer_shape = brgemm_repack->get_repacking_buffer_shape ();
7978 OV_CPU_JIT_EMITTER_ASSERT (!repacking_buffer_shape.empty (), " Repacking buffer shape mustn't be empty" );
80- size_t LDB = transpose_b_enable ? 384 : repacking_buffer_shape.back ();
79+ size_t LDB = repacking_buffer_shape.back ();
8180 if (auto val = std::getenv (" LDB" )) {
8281 LDB = std::atoi (val);
8382 }
@@ -93,18 +92,17 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
9392 const auto src_dt = static_cast <dnnl_data_type_t >(DnnlExtensionUtils::ElementTypeToDataType (brg_src_etype));
9493 const auto wei_dt = static_cast <dnnl_data_type_t >(DnnlExtensionUtils::ElementTypeToDataType (m_brg_weight_etype));
9594
96- init_brgemm_copy (m_kernel, format, wei_stride, N, m_inner_N_block, m_inner_N_tail, LDB, m_K_blk, use_amx, src_dt, wei_dt);
95+ init_brgemm_copy (m_kernel, N, m_inner_N_block, m_inner_N_tail, LDB, m_K_blk, use_amx, src_dt, wei_dt, wei_stride );
9796}
9897
9998void jit_brgemm_copy_b_emitter::init_brgemm_copy (std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t >& kernel,
100- dnnl_format_tag_t format, size_t wei_stride,
10199 size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
102- bool is_with_amx, dnnl_data_type_t src_dt, dnnl_data_type_t wei_dt) const {
100+ bool is_with_amx, dnnl_data_type_t src_dt, dnnl_data_type_t wei_dt, size_t wei_stride ) const {
103101 matmul::brgemm_matmul_conf_t brgCopyKernelConf;
104102 brgCopyKernelConf.src_dt = src_dt;
105103 brgCopyKernelConf.wei_dt = wei_dt;
106104 brgCopyKernelConf.wei_n_blk = static_cast <int >(N_blk);
107- brgCopyKernelConf.wei_tag = format ;
105+ brgCopyKernelConf.wei_tag = m_format ;
108106 brgCopyKernelConf.copy_B_wei_stride = wei_stride;
109107 brgCopyKernelConf.LDB = static_cast <dim_t >(LDB);
110108 brgCopyKernelConf.N = static_cast <dim_t >(N);
@@ -148,10 +146,12 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const s
148146 Xbyak::Reg64 comp (static_cast <int >(m_with_comp ? out[1 ] : 0 ));
149147
150148 const size_t data_size = m_brg_weight_etype.size ();
149+ const size_t K_scale = m_format == dnnl_abdc ? m_K_blk : 1 ;
151150 for (size_t nb = 0 ; nb < div_up (m_N_blk, m_inner_N_block); nb++) {
152- const size_t offset_in = m_in_offset + nb * m_inner_N_block * data_size;
151+ const size_t offset_in = m_in_offset + nb * m_inner_N_block * K_scale * data_size;
153152 const size_t offset_out = m_out_offset + nb * m_inner_N_block * m_brgemmVNNIFactor * data_size;
154153 const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_inner_N_block * sizeof (int32_t ) : 0 ;
154+ std::cout << " offset in [" << nb << " ] = " << offset_in << std::endl;
155155
156156 const bool is_N_tail = (m_N_blk - nb * m_inner_N_block < m_inner_N_block);
157157 const auto current_N_blk = is_N_tail ? m_inner_N_tail : m_inner_N_block;
0 commit comments