Skip to content

Commit e0c985f

Browse files
committed
Cleanup
1 parent 1c77cfe commit e0c985f

File tree

7 files changed

+49
-20
lines changed

7 files changed

+49
-20
lines changed

src/common/snippets/include/snippets/lowered/pass/brgemm_blocking.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ class BrgemmBlockingBase : public snippets::lowered::pass::RangedPass {
2828
/**
2929
* @interface mark_blocking_loops
3030
* @brief Covers brgemm with blocking loops. Also should calculate optimal blocking parameters inside.
31-
* @param linear_ir LIR that contain's brgemm
31+
* @param linear_ir LIR that contains brgemm
3232
* @param brgemm_it iterator on brgemm expression which should be covered with blocking loops
3333
*/
3434
virtual bool mark_blocking_loops(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) = 0;
35-
// virtual std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) = 0;
3635

3736
static bool blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
3837
const ov::snippets::lowered::ExpressionPtr& brgemm_expr,

src/common/snippets/include/snippets/pass/explicit_transpose_matmul_inputs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class ExplicitTransposeMatMulInputs: public ov::pass::MatcherPass {
3030
static bool are_weights_scalar(const std::shared_ptr<ov::Node>& node);
3131

3232
private:
33-
static bool extract(const ov::Input<ov::Node>& input);
33+
static void extract(const ov::Input<ov::Node>& input);
3434
};
3535

3636
} // namespace pass

src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const
1919
});
2020
}
2121

22-
bool ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<ov::Node>& input) {
22+
void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<ov::Node>& input) {
2323
auto parent = input.get_source_output().get_node_shared_ptr();
2424
auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent);
2525
while (!transpose && !ov::is_type<ov::op::v0::Parameter>(parent)) {
@@ -47,7 +47,7 @@ bool ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<
4747
new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
4848
ov::copy_runtime_info(transpose_pattern, new_transpose_order);
4949
transpose->set_argument(1, new_transpose_order);
50-
return true;
50+
return;
5151
}
5252

5353
// Create new Transpose before Parameter
@@ -68,7 +68,6 @@ bool ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<
6868
const auto new_transpose = std::make_shared<opset1::Transpose>(parent, constant_order); // parent is Parameter
6969
const auto consumer_input = *(consumers.begin());
7070
consumer_input.replace_source_output(new_transpose);
71-
return true;
7271
}
7372

7473
ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs() {
@@ -86,12 +85,14 @@ ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs
8685
if (!matmul)
8786
return false;
8887

89-
if (matmul->get_transpose_a() && extract(matmul->input(0))) {
88+
if (matmul->get_transpose_a()) {
89+
extract(matmul->input(0));
9090
matmul->set_transpose_a(false);
9191
rewritten |= true;
9292
}
9393

94-
if (matmul->get_transpose_b() && !transformation_callback(matmul) && extract(matmul->input(1))) {
94+
if (matmul->get_transpose_b() && !transformation_callback(matmul)) {
95+
extract(matmul->input(1));
9596
matmul->set_transpose_b(false);
9697
rewritten |= true;
9798
}

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ size_t jit_brgemm_copy_b_emitter::compute_vnni_factor(const ov::element::Type& p
7373
}
7474

7575
size_t jit_brgemm_copy_b_emitter::get_elems_in_vec(const ov::element::Type& precision) {
76-
OV_CPU_JIT_EMITTER_ASSERT(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core), "doesn't support non avx512 platforms");
77-
const auto vlen = dnnl::impl::cpu::x64::cpu_isa_traits<dnnl::impl::cpu::x64::avx512_core>::vlen;
76+
using namespace dnnl::impl::cpu;
77+
OV_CPU_JIT_EMITTER_ASSERT(x64::mayiuse(x64::avx2), "doesn't support platforms lower than avx2");
78+
const auto vlen = x64::mayiuse(x64::avx512_core) ? x64::cpu_isa_traits<x64::avx512_core>::vlen : x64::cpu_isa_traits<x64::avx2>::vlen;
7879
return vlen / precision.size();
7980
}
8081

@@ -84,6 +85,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
8485
const auto brgemm_repack = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(expr->get_node());
8586
if (!brgemm_repack)
8687
OV_CPU_JIT_EMITTER_THROW("expects BrgemmCopyB node");
88+
OV_CPU_JIT_EMITTER_ASSERT(is_superset(host_isa_, cpu::x64::avx2), "host_isa must be at least avx2");
8789
m_with_comp = with_compensations(brgemm_repack->get_type());
8890
m_in_offset = brgemm_repack->get_offset_in();
8991
m_out_offset = brgemm_repack->get_offset_out();

src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,36 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
2525
return {{element::i8}, {element::bf16}, {element::f32}};
2626
}
2727

28+
/**
29+
* @brief Computes buffer size that OneDNN impl needs for repacked tensor
30+
* @param copy_b_expr Repacking expression whose information (tensor precision, layout, subtensors) is used for
31+
* buffer size computations
32+
*/
2833
static size_t get_repacking_buffer_size(const ov::snippets::lowered::ExpressionPtr& copy_b_expr);
34+
/**
35+
* @brief Computes buffer size that OneDNN impl needs for compensations
36+
* @param copy_b_expr Repacking expression whose information (tensor precision, subtensors) is used for
37+
* buffer size computations
38+
*/
2939
static size_t get_compensations_buffer_size(const ov::snippets::lowered::ExpressionPtr& copy_b_expr);
3040

41+
/**
42+
* @brief Computes leading dimension (LDB) which must be used in brgemm and brgemm_copy_b emitters
43+
* @param n_block N block size shared between BrgemmCPU and BrgemmCopyB node
44+
* @param precision tensor precision
45+
*/
3146
static size_t compute_repacking_out_leading_dim(const size_t n_block, const ov::element::Type& precision);
47+
/**
48+
* @brief Computes inner N block size used by OneDNN implementation. Depends on tensor precision
49+
*/
3250
static size_t compute_inner_n_block(const ov::element::Type& precision);
51+
/**
52+
* @brief Computes VNNI factor used by OneDNN implementation. Depends on tensor precision
53+
*/
3354
static size_t compute_vnni_factor(const ov::element::Type& precision);
55+
/**
56+
* @brief Computes number of elems with requested precision that fit in the vector register
57+
*/
3458
static size_t get_elems_in_vec(const ov::element::Type& precision);
3559

3660
private:

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "snippets/utils/utils.hpp"
1414
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
1515
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
16-
#include "transformations/tpp/x64/op/brgemm.hpp"
1716

1817

1918
namespace ov {
@@ -23,6 +22,7 @@ using LinearIR = snippets::lowered::LinearIR;
2322
using LoopPort = snippets::lowered::LoopPort;
2423
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;
2524
using namespace ov::snippets::lowered;
25+
using namespace ov::intel_cpu::brgemm_utils;
2626

2727
LinearIR::constExprIt BrgemmCPUBlocking::move_new_memory_buffer(LinearIR& linear_ir, const LinearIR::constExprIt& brgemm_it) {
2828
const auto& brgemm_expr = brgemm_it->get();
@@ -42,7 +42,7 @@ LinearIR::constExprIt BrgemmCPUBlocking::get_loop_begin_pos(LinearIR& linear_ir,
4242
const auto& brgemm_expr = *brgemm_it;
4343
const auto brgemm = ov::as_type_ptr<intel_cpu::BrgemmCPU>(brgemm_expr->get_node());
4444
OPENVINO_ASSERT(brgemm, "get_loop_begin_pos must be called only for BrgemmCPU expression");
45-
if (ov::intel_cpu::brgemm_utils::with_amx(brgemm->get_type()))
45+
if (with_amx(brgemm->get_type()))
4646
loop_begin_it = move_new_memory_buffer(linear_ir, brgemm_it);
4747
if (copy_b_expr)
4848
loop_begin_it = linear_ir.find(copy_b_expr);
@@ -72,7 +72,6 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, const LinearIR:
7272
const auto& k = *in_0_planar_dims.rbegin();
7373
OPENVINO_ASSERT(k == *++in_1_planar_dims.rbegin(), "Brgemm input descriptors have different K dimension value.");
7474
const auto type = brgemm->get_type();
75-
const bool with_repacking = ov::intel_cpu::brgemm_utils::with_repacking(type);
7675

7776
// Ticket: 113745
7877
// TODO: extend block size selection heuristics
@@ -83,15 +82,15 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, const LinearIR:
8382
// K blocking is disabled in dynamism by default
8483
if (ov::snippets::utils::is_dynamic_value(K))
8584
return snippets::utils::get_dynamic_value<size_t>();
86-
if (with_repacking)
85+
if (with_repacking(type))
8786
return K;
8887
return K > 1024 ? 1024 : K > 512 ? 512 : K;
8988
};
9089
auto get_block_size_n = [&](const size_t N) -> size_t {
9190
// N blocking is disabled in dynamism by default
9291
if (ov::snippets::utils::is_dynamic_value(N))
9392
return snippets::utils::get_dynamic_value<size_t>();
94-
if (with_repacking)
93+
if (with_repacking(type))
9594
return N;
9695
return std::min<size_t>(64, N);
9796
};
@@ -105,13 +104,13 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, const LinearIR:
105104
brgemm_expr->get_output_port_descriptor(0)->set_subtensor(ov::snippets::VectorDims{block_size_m, block_size_n});
106105

107106
ov::snippets::lowered::ExpressionPtr copy_b_expr = nullptr;
108-
if (with_repacking) {
107+
if (with_repacking(type)) {
109108
const auto copy_b = brgemm->get_brgemm_copy();
110109
copy_b_expr = linear_ir.get_expr_by_node(copy_b);
111110
const ov::snippets::VectorDims repacking_subtensor{block_size_k, block_size_n};
112111
copy_b_expr->get_input_port_descriptor(0)->set_subtensor(repacking_subtensor);
113112
copy_b_expr->get_output_port_descriptor(0)->set_subtensor(repacking_subtensor);
114-
if (ov::intel_cpu::brgemm_utils::with_compensations(type)) {
113+
if (with_compensations(type)) {
115114
const ov::snippets::VectorDims compensations_subtensor{1, block_size_n};
116115
OPENVINO_ASSERT(brgemm_expr->get_input_count() == 3, "Brgemm must have 3 inputs in case of compensations.");
117116
brgemm_expr->get_input_port_descriptor(2)->set_subtensor(compensations_subtensor);
@@ -126,7 +125,7 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, const LinearIR:
126125

127126
const auto b_input_port = include_repacking && copy_b_expr ? copy_b_expr->get_input_port(0) : brgemm_expr->get_input_port(1);
128127
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true), LoopPort(b_input_port, false)};
129-
if (!include_repacking && ov::intel_cpu::brgemm_utils::with_compensations(type))
128+
if (!include_repacking && with_compensations(type))
130129
entries.emplace_back(brgemm_expr->get_input_port(2), false);
131130
const std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
132131
loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits);

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,8 +1105,7 @@ void Transformations::MainSnippets(void) {
11051105
},
11061106
snippets::pass::TokenizeSnippets);
11071107

1108-
CPU_SET_CALLBACK_COMMON(snippetsManager,
1109-
[this](const std::shared_ptr<const ov::Node>& n) -> bool {
1108+
auto mm_supports_transpose_b = [this](const std::shared_ptr<const ov::Node>& n) {
11101109
MAYBE_UNUSED(inferencePrecision);
11111110
const auto& b_shape = n->get_input_partial_shape(1);
11121111
if (b_shape.is_dynamic())
@@ -1137,6 +1136,11 @@ void Transformations::MainSnippets(void) {
11371136
return false;
11381137
#endif
11391138
return true;
1139+
};
1140+
1141+
CPU_SET_CALLBACK_COMMON(snippetsManager,
1142+
[&mm_supports_transpose_b](const std::shared_ptr<const ov::Node>& n) {
1143+
return mm_supports_transpose_b(n);
11401144
},
11411145
snippets::pass::ExplicitTransposeMatMulInputs);
11421146

0 commit comments

Comments
 (0)