
Commit 4768cf9

CopyB with transpose works for u8i8 case
1 parent 89e153e commit 4768cf9

File tree: 3 files changed, +17 -16 lines


src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,6 @@ ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs
                                             ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
                                             ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));
 
-    const bool skip_transpose_b_extraction = std::getenv("TRANSPOSE_B");
     register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
         [=](ov::pass::pattern::Matcher &m) {
             OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
@@ -92,6 +91,7 @@ ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs
                 matmul->set_transpose_a(false);
                 rewritten |= true;
             }
+            const bool skip_transpose_b_extraction = std::getenv("TRANSPOSE_B") && matmul->get_input_element_type(1) != ov::element::f32;
             if (matmul->get_transpose_b() && !skip_transpose_b_extraction) {
                 extract(matmul->input(1));
                 std::cout << "[ INFO ] ExplicitTransposeMatMulInputs is finished for B input\n";
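
Note: the intent of the added condition is that the TRANSPOSE_B environment flag skips the explicit transpose-B extraction only for non-f32 weights (the u8/i8 case, where BrgemmCopyB repacks B anyway). A minimal standalone sketch of that gate; the helper name keep_transpose_b_on_matmul is hypothetical and not part of the patch:

    #include <cstdlib>
    #include <memory>
    #include "openvino/op/matmul.hpp"

    // Hypothetical helper mirroring the condition added above: keep transpose_b on
    // the MatMul (skip extraction of an explicit Transpose on input B) only when the
    // TRANSPOSE_B env flag is set and the B input precision is not f32.
    bool keep_transpose_b_on_matmul(const std::shared_ptr<ov::op::v0::MatMul>& matmul) {
        const bool env_flag = std::getenv("TRANSPOSE_B") != nullptr;
        const bool non_f32_weights = matmul->get_input_element_type(1) != ov::element::f32;
        return env_flag && non_f32_weights;
    }
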

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp

Lines changed: 13 additions & 13 deletions
@@ -22,8 +22,6 @@ using namespace dnnl::impl::cpu::x64;
 namespace ov {
 namespace intel_cpu {
 
-const bool transpose_b_enable = std::getenv("TRANSPOSE_B");
-
 jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr)
     : jit_emitter(h, isa) {
     in_out_type_ = emitter_in_out_map::gpr_to_gpr;
@@ -43,16 +41,17 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
     const auto transposed_shape = layout.empty() ? original_shape : snippets::utils::get_planar_vdims(original_shape, layout);
     const size_t N = transposed_shape.back();
 
-    dnnl_format_tag_t format = dnnl_abcd;
+    m_format = dnnl_abcd;
     size_t wei_stride = 0;
     if (layout == VectorDims{0, 2, 1, 3}) {
-        format = dnnl_acbd;
+        std::cout << "wei stride is needed!!\n";
+        m_format = dnnl_acbd;
         wei_stride = jit_brgemm_emitter::get_in_leading_dim(original_shape, layout);
+    } else if (layout == VectorDims{0, 1, 3, 2}) {
+        std::cout << "transposed copy_b shape\n";
+        m_format = dnnl_abdc;
     }
 
-    if (transpose_b_enable)
-        format = dnnl_abdc;
-
     std::cout << "[ INFO ] CopyBEmitter is processing...\n";
     std::cout << "\tshape = " << ov::PartialShape(original_shape) << std::endl;
     std::cout << "\tlayout = " << ov::PartialShape(layout) << std::endl;
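
The layout handling above amounts to a small mapping: a planar B keeps dnnl_abcd, the batch-transposed layout {0, 2, 1, 3} maps to dnnl_acbd and additionally needs copy_B_wei_stride, and the K/N-transposed layout {0, 1, 3, 2} (what keeping transpose_b on the MatMul produces in the u8/i8 case) maps to dnnl_abdc. A sketch of that mapping as a free function; copy_b_wei_tag is hypothetical, the real code assigns the m_format member inline:

    #include <cstddef>
    #include <vector>
    #include "oneapi/dnnl/dnnl_types.h"

    using VectorDims = std::vector<size_t>;

    // Hypothetical helper: choose the oneDNN weight tag for BrgemmCopyB from the
    // input layout, mirroring the branches added in the emitter constructor.
    dnnl_format_tag_t copy_b_wei_tag(const VectorDims& layout) {
        if (layout == VectorDims{0, 2, 1, 3})
            return dnnl_acbd;  // batch-transposed B, needs an explicit weight stride
        if (layout == VectorDims{0, 1, 3, 2})
            return dnnl_abdc;  // transpose_b kept on the MatMul (u8/i8 path)
        return dnnl_abcd;      // planar or empty layout
    }
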
@@ -77,7 +76,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
     OV_CPU_JIT_EMITTER_ASSERT(!one_of(m_brg_weight_etype, element::bf16, element::i8), "doesn't support precision ", m_brg_weight_etype);
     const auto repacking_buffer_shape = brgemm_repack->get_repacking_buffer_shape();
     OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty");
-    size_t LDB = transpose_b_enable ? 384 : repacking_buffer_shape.back();
+    size_t LDB = repacking_buffer_shape.back();
     if (auto val = std::getenv("LDB")) {
         LDB = std::atoi(val);
     }
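
After this change LDB comes from the repacking buffer shape rather than the TRANSPOSE_B-gated hard-coded 384, with the LDB environment variable kept as a manual override. For reference, a sketch of that resolution; resolve_ldb is a hypothetical helper, the real code is inline in the constructor:

    #include <cstddef>
    #include <cstdlib>
    #include <vector>

    // Hypothetical helper mirroring the LDB logic above.
    size_t resolve_ldb(const std::vector<size_t>& repacking_buffer_shape) {
        size_t LDB = repacking_buffer_shape.back();
        if (const char* val = std::getenv("LDB"))  // experimental override
            LDB = static_cast<size_t>(std::atoi(val));
        return LDB;
    }
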
@@ -93,18 +92,17 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
     const auto src_dt = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(brg_src_etype));
     const auto wei_dt = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(m_brg_weight_etype));
 
-    init_brgemm_copy(m_kernel, format, wei_stride, N, m_inner_N_block, m_inner_N_tail, LDB, m_K_blk, use_amx, src_dt, wei_dt);
+    init_brgemm_copy(m_kernel, N, m_inner_N_block, m_inner_N_tail, LDB, m_K_blk, use_amx, src_dt, wei_dt, wei_stride);
 }
 
 void jit_brgemm_copy_b_emitter::init_brgemm_copy(std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>& kernel,
-                                                 dnnl_format_tag_t format, size_t wei_stride,
                                                  size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
-                                                 bool is_with_amx, dnnl_data_type_t src_dt, dnnl_data_type_t wei_dt) const {
+                                                 bool is_with_amx, dnnl_data_type_t src_dt, dnnl_data_type_t wei_dt, size_t wei_stride) const {
     matmul::brgemm_matmul_conf_t brgCopyKernelConf;
     brgCopyKernelConf.src_dt = src_dt;
     brgCopyKernelConf.wei_dt = wei_dt;
     brgCopyKernelConf.wei_n_blk = static_cast<int>(N_blk);
-    brgCopyKernelConf.wei_tag = format;
+    brgCopyKernelConf.wei_tag = m_format;
     brgCopyKernelConf.copy_B_wei_stride = wei_stride;
     brgCopyKernelConf.LDB = static_cast<dim_t>(LDB);
     brgCopyKernelConf.N = static_cast<dim_t>(N);
@@ -148,10 +146,12 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector<size_t>& in, const s
     Xbyak::Reg64 comp(static_cast<int>(m_with_comp ? out[1] : 0));
 
     const size_t data_size = m_brg_weight_etype.size();
+    const size_t K_scale = m_format == dnnl_abdc ? m_K_blk : 1;
     for (size_t nb = 0; nb < div_up(m_N_blk, m_inner_N_block); nb++) {
-        const size_t offset_in = m_in_offset + nb * m_inner_N_block * data_size;
+        const size_t offset_in = m_in_offset + nb * m_inner_N_block * K_scale * data_size;
         const size_t offset_out = m_out_offset + nb * m_inner_N_block * m_brgemmVNNIFactor * data_size;
         const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_inner_N_block * sizeof(int32_t) : 0;
+        std::cout << "offset in [" << nb << "] = " << offset_in << std::endl;
 
         const bool is_N_tail = (m_N_blk - nb * m_inner_N_block < m_inner_N_block);
         const auto current_N_blk = is_N_tail ? m_inner_N_tail : m_inner_N_block;
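
The K_scale factor is the core of the u8/i8 transpose fix: with the dnnl_abdc (transposed) weight tag, K is the innermost dimension of the source B block, so stepping along N strides over m_K_blk elements and the per-block input offset has to be scaled accordingly. A minimal sketch of that arithmetic; copy_b_input_offset is a hypothetical helper, the real computation lives in emit_impl:

    #include <cstddef>

    // Hypothetical helper reproducing the per-block input offset from emit_impl.
    size_t copy_b_input_offset(size_t base_offset,    // m_in_offset
                               size_t nb,             // inner N-block index
                               size_t inner_N_block,  // m_inner_N_block
                               size_t K_blk,          // m_K_blk
                               size_t data_size,      // weight element size in bytes
                               bool transposed_b) {   // m_format == dnnl_abdc
        const size_t K_scale = transposed_b ? K_blk : 1;
        return base_offset + nb * inner_N_block * K_scale * data_size;
    }
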

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp

Lines changed: 3 additions & 2 deletions
@@ -27,9 +27,8 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
     void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
 
     void init_brgemm_copy(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& kernel,
-                          dnnl_format_tag_t format, size_t wei_stride,
                           size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K,
-                          bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const;
+                          bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, size_t wei_stride) const;
     void emit_kernel_call(const dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel,
                           Xbyak::Reg64 src, Xbyak::Reg64 dst, Xbyak::Reg64 comp, size_t N, size_t K,
                           size_t offset_in, size_t offset_out, size_t offset_comp) const;
@@ -56,6 +55,8 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
 
     bool m_with_comp = false;
 
+    dnnl_format_tag_t m_format;
+
 #ifdef SNIPPETS_DEBUG_CAPS
     friend std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter);
 #endif
