Skip to content

Commit 7dc0e0f

Browse files
committed
[Snippets] Support FP32/BF16/I8 matmuls with transpose_b=true via BrgemmCopyB
1 parent a3ed68a commit 7dc0e0f

File tree

56 files changed

+1033
-850
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1033
-850
lines changed

src/common/snippets/docs/mha_optimization_guide.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ For enhancing the execution efficiency, blocking across the M, K, and N matmul d
123123

124124
### Blocking Parameters
125125

126-
The heuristics for determining the optimal block sizes can be found in [SetBrgemmCPUBlockingParams](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp).
126+
The heuristics for determining the optimal block sizes can be found in [BrgemmCPUBlocking](../../../plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp).
127127

128128
**Please note: Blocking by M dimension is shared between both Brgemms. Please see [SplitLoops](../include/snippets/lowered/pass/split_loops.hpp) lowered pass for the details.**
129129

@@ -141,7 +141,7 @@ Based on previously discussed information, we provide the following recommendati
141141
In local experiments, some transformations might be worth changing:
142142
- Disable [ExtractUnsupportedTransposes](#extractunsupportedtransposes) transformation in order to benchmark Snippets Transpose implementation.
143143
- Adjust [SplitDimensionM](#splitdimensionm) heuristics in order to benchmark another splitting, or disable the pass at all.
144-
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `SetBrgemmCPUBlockingParams`.
144+
3. [Blocking parameters](#blocking-parameters): adjust blocking heuristics in `BrgemmCPUBlocking`.
145145
- Please note that there are 2 Matmul nodes inside a single MHA, and each Matmul can have its own optimal K, N blocking params.
146146
It is better to keep the M block the same, since the corresponding blocking loop is shared between both Matmuls.
147147
- For the BF16/INT8 blocking loops, 2 options are possible: blocking can be done only for Brgemm node, or for BrgemmCopyB repacking too.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "snippets/lowered/pass/pass.hpp"
8+
#include "snippets/lowered/specific_loop_iter_handlers.hpp"
9+
#include "snippets/op/brgemm.hpp"
10+
11+
namespace ov {
12+
namespace snippets {
13+
namespace lowered {
14+
namespace pass {
15+
16+
/**
17+
* @interface BrgemmBlockingBase
18+
* @brief Base class for Brgemm blocking loops markup
19+
* @ingroup snippets
20+
*/
21+
class BrgemmBlockingBase : public snippets::lowered::pass::RangedPass {
22+
public:
23+
OPENVINO_RTTI("BrgemmBlockingBase", "RangedPass")
24+
bool run(snippets::lowered::LinearIR& linear_ir,
25+
snippets::lowered::LinearIR::constExprIt begin,
26+
snippets::lowered::LinearIR::constExprIt end) override;
27+
28+
static snippets::lowered::SpecificIterationHandlers get_default_blocking_loop_handlers(size_t work_amount, size_t block_size);
29+
30+
protected:
31+
/**
32+
* @interface mark_blocking_loops
33+
* @brief Covers brgemm with blocking loops. Also should calculate optimal blocking parameters inside.
34+
* @param linear_ir LIR that contains brgemm
35+
* @param brgemm_it iterator on brgemm expression which should be covered with blocking loops
36+
*/
37+
virtual bool mark_blocking_loops(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) = 0;
38+
39+
static bool blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
40+
const ov::snippets::lowered::ExpressionPtr& brgemm_expr,
41+
const std::shared_ptr<ov::snippets::op::Brgemm>& brgemm);
42+
};
43+
44+
} // namespace pass
45+
} // namespace lowered
46+
} // namespace snippets
47+
} // namespace ov

src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ class TransformInnerSplitLoop : public pass::RangedPass {
6464
size_t m_tail_size;
6565
};
6666

67+
/**
68+
* @interface SetEvaluateOnce
69+
 * @brief The pass sets `evaluate once = true` only on the ExpandedLoopInfo which is mapped to the LoopEnd in the passed iterator `end`.
70+
* The pointer arithmetic should be updated in the separate optimization `OptimizeLoopSingleEvaluation`
71+
* @ingroup snippets
72+
*/
73+
class SetEvaluateOnce : public snippets::lowered::pass::RangedPass {
74+
public:
75+
SetEvaluateOnce() = default;
76+
OPENVINO_RTTI("SetEvaluateOnce", "RangedPass")
77+
bool run(snippets::lowered::LinearIR& linear_ir,
78+
snippets::lowered::LinearIR::constExprIt begin,
79+
snippets::lowered::LinearIR::constExprIt end) override;
80+
std::shared_ptr<snippets::lowered::pass::PassBase> merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) override;
81+
};
6782
} // namespace pass
6883
} // namespace lowered
6984
} // namespace snippets

src/common/snippets/include/snippets/op/brgemm.hpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,17 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
2222
OPENVINO_OP("Brgemm", "SnippetsOpset");
2323
Brgemm(const Output<Node>& A, const Output<Node>& B,
2424
const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu,
25-
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
26-
size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
25+
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
2726
Brgemm(const Output<Node>& A, const Output<Node>& B,
2827
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
29-
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
30-
size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
28+
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
3129
Brgemm() = default;
3230

3331
size_t get_offset_a() const { return get_input_offset(0); }
3432
size_t get_offset_b() const { return get_input_offset(1); }
3533
size_t get_offset_c() const { return get_output_offset(0); }
3634

37-
size_t get_m_block_size() const { return m_M_blk; }
38-
size_t get_k_block_size() const { return m_K_blk; }
39-
size_t get_n_block_size() const { return m_N_blk; }
4035
float get_beta() const { return m_beta; }
41-
42-
void set_m_block_size(size_t block_size) { m_M_blk = block_size; }
43-
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
44-
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
4536
void set_beta(float beta) { m_beta = beta; }
4637

4738
static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);
@@ -57,10 +48,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
5748
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
5849
ov::PartialShape infer_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
5950
ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
60-
void set_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n);
61-
size_t m_M_blk = 0;
62-
size_t m_K_blk = 0;
63-
size_t m_N_blk = 0;
6451
float m_beta = 0.f;
6552

6653
private:

src/common/snippets/include/snippets/pass/matmul_to_brgemm.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@ class MatMulToBrgemm: public ov::pass::MatcherPass {
2222
public:
2323
OPENVINO_RTTI("MatMulToBrgemm", "0");
2424
MatMulToBrgemm();
25-
26-
private:
27-
void init_ports(const std::shared_ptr<op::Brgemm>& brgemm) const;
2825
};
2926

3027

src/common/snippets/include/snippets/utils/utils.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ bool broadcast_merge_dim(size_t& dst, const size_t& d1, const size_t& d2);
127127
VectorDims pshape_to_vdims(const PartialShape&);
128128
ov::PartialShape vdims_to_pshape(const VectorDims&);
129129

130+
inline size_t dimension_to_size_t(const ov::Dimension& dim) {
131+
return dim.is_dynamic() ? snippets::utils::get_dynamic_value<VectorDims::value_type>() : static_cast<size_t>(dim.get_length());
132+
}
133+
130134
// dim_idx starts from the layout end: dim_idx = 0 -> last element in layout (layout.back())
131135
inline size_t get_input_dim_idx(const std::vector<size_t>& layout, size_t dim_idx) {
132136
OPENVINO_ASSERT(dim_idx < layout.size(), "Incorrect dim_idx");
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "snippets/lowered/pass/brgemm_blocking.hpp"
6+
7+
#include "snippets/itt.hpp"
8+
#include "snippets/lowered/linear_ir.hpp"
9+
#include "snippets/lowered/loop_manager.hpp"
10+
#include "snippets/lowered/pass/pass.hpp"
11+
#include "snippets/lowered/pass/propagate_subtensors.hpp"
12+
#include "snippets/lowered/pass/iter_handler.hpp"
13+
#include "snippets/snippets_isa.hpp"
14+
#include "snippets/utils/utils.hpp"
15+
16+
namespace ov {
17+
namespace snippets {
18+
namespace lowered {
19+
namespace pass {
20+
21+
snippets::lowered::SpecificIterationHandlers BrgemmBlockingBase::get_default_blocking_loop_handlers(size_t work_amount, size_t block_size) {
22+
SpecificIterationHandlers handlers;
23+
const auto tail_size = snippets::utils::is_dynamic_value(work_amount) ? snippets::utils::get_dynamic_value<size_t>() : work_amount % block_size;
24+
if (tail_size != 0)
25+
handlers.register_pass<snippets::lowered::SpecificLoopIterType::LAST_ITER, snippets::lowered::pass::UpdateSubtensors>(tail_size);
26+
handlers.register_pass<snippets::lowered::SpecificLoopIterType::LAST_ITER, SetEvaluateOnce>();
27+
return handlers;
28+
}
29+
30+
bool BrgemmBlockingBase::blocking_loop_exists(const snippets::lowered::LoopManagerPtr& loop_manager,
31+
const ExpressionPtr& brgemm_expr,
32+
const std::shared_ptr<snippets::op::Brgemm>& brgemm) {
33+
auto check_port = [&](const LoopPort& p) {
34+
return p.expr_port->get_expr() == brgemm_expr && ov::snippets::utils::one_of(p.dim_idx, 0ul, 1ul);
35+
};
36+
37+
const auto& loop_ids = brgemm_expr->get_loop_ids();
38+
for (const auto& id : loop_ids) {
39+
const auto loop = loop_manager->get_loop_info(id);
40+
if (std::any_of(loop->get_input_ports().begin(), loop->get_input_ports().end(), check_port) ||
41+
std::any_of(loop->get_output_ports().begin(), loop->get_output_ports().end(), check_port)) {
42+
return true;
43+
}
44+
}
45+
return false;
46+
}
47+
48+
bool BrgemmBlockingBase::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
49+
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmCPUBlocking")
50+
const auto& loop_manager = linear_ir.get_loop_manager();
51+
bool modified = false;
52+
for (auto expr_it = begin; expr_it != end; expr_it++) {
53+
const auto& brgemm_expr = *expr_it;
54+
const auto& node = brgemm_expr->get_node();
55+
const auto brgemm = ov::as_type_ptr<ov::snippets::op::Brgemm>(node);
56+
if (!brgemm || blocking_loop_exists(loop_manager, brgemm_expr, brgemm))
57+
continue;
58+
modified = mark_blocking_loops(linear_ir, expr_it);
59+
}
60+
61+
return modified;
62+
}
63+
64+
} // namespace pass
65+
} // namespace lowered
66+
} // namespace snippets
67+
} // namespace ov

src/common/snippets/src/lowered/pass/iter_handler.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,18 @@ std::shared_ptr<pass::PassBase> TransformInnerSplitLoop::merge(const std::shared
142142
return merged_pass;
143143
}
144144

145+
bool SetEvaluateOnce::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
146+
const auto& loop_end = ov::as_type_ptr<snippets::op::LoopEnd>(end->get()->get_node());
147+
OPENVINO_ASSERT(loop_end, "SetEvaluateOnce expected LoopEnd node in iterator `end`.");
148+
const auto& loop_info = linear_ir.get_loop_manager()->get_loop_info<ov::snippets::lowered::ExpandedLoopInfo>(loop_end->get_id());
149+
loop_info->set_evaluate_once(true);
150+
return true;
151+
}
152+
153+
std::shared_ptr<snippets::lowered::pass::PassBase> SetEvaluateOnce::merge(const std::shared_ptr<snippets::lowered::pass::PassBase>& other) {
154+
return !other || ov::is_type<SetEvaluateOnce>(other) ? std::make_shared<SetEvaluateOnce>() : nullptr;
155+
}
156+
145157
} // namespace pass
146158
} // namespace lowered
147159
} // namespace snippets

src/common/snippets/src/op/brgemm.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,33 +32,23 @@ std::vector<size_t> get_output_layout(const std::shared_ptr<const ov::Node>& n)
3232

3333
Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
3434
const size_t offset_a, const size_t offset_b, const size_t offset_c,
35-
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
36-
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
35+
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
3736
: MemoryAccess(std::set<size_t>{0, 1}, std::set<size_t>{0}), Op({A, B}) {
3837
set_output_size(1);
3938
set_input_offset(offset_a, 0);
4039
set_input_offset(offset_b, 1);
4140
set_output_offset(offset_c, 0);
42-
set_block_size_values(blk_size_m, blk_size_k, blk_size_n);
4341
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
4442
}
4543

4644
Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
4745
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
48-
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
49-
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
46+
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
5047
: MemoryAccess(PortMap{{0, desc_a}, {1, desc_b}}, PortMap{{0, desc_c}}), Op({A, B}) {
5148
set_output_size(1);
52-
set_block_size_values(blk_size_m, blk_size_k, blk_size_n);
5349
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
5450
}
5551

56-
void Brgemm::set_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) {
57-
m_M_blk = blk_size_m;
58-
m_K_blk = blk_size_k;
59-
m_N_blk = blk_size_n;
60-
}
61-
6252
void Brgemm::custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c) {
6353
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
6454

@@ -90,9 +80,6 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
9080
}
9181

9282
bool Brgemm::visit_attributes(AttributeVisitor& visitor) {
93-
visitor.on_attribute("blk_M", m_M_blk);
94-
visitor.on_attribute("blk_K", m_K_blk);
95-
visitor.on_attribute("blk_N", m_N_blk);
9683
visitor.on_attribute("beta", m_beta);
9784
return MemoryAccess::visit_attributes(visitor);
9885
}

src/common/snippets/src/pass/explicit_transpose_matmul_inputs.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
66

7-
#include "snippets/op/subgraph.hpp"
8-
#include "snippets/itt.hpp"
9-
7+
#include "openvino/core/rt_info.hpp"
108
#include "openvino/pass/pattern/matcher.hpp"
119
#include "openvino/pass/pattern/op/wrap_type.hpp"
12-
#include "openvino/core/rt_info.hpp"
10+
#include "snippets/itt.hpp"
11+
#include "snippets/op/subgraph.hpp"
12+
#include "snippets/pass/mha_tokenization.hpp"
1313

1414
bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const std::shared_ptr<ov::Node>& node) {
1515
const auto inputs = node->inputs();
@@ -58,6 +58,7 @@ void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<
5858
"ExplicitTransposeMatMulInputs expects Parameter with one consumer in cases when there isn't existing Transpose on input");
5959
// Extract Transpose from MatMul
6060
OPENVINO_ASSERT(input.get_partial_shape().rank().is_static(), "ExplicitTransposeMatMulInputs supports only static ranks of shapes");
61+
6162
const auto rank = input.get_partial_shape().size();
6263
std::vector<size_t> transpose_order(rank, 0);
6364
std::iota(transpose_order.begin(), transpose_order.end(), 0);
@@ -75,7 +76,7 @@ ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs
7576
auto m_matmul0 = std::make_shared<ov::op::v0::MatMul>(ov::pass::pattern::any_input(), ov::pass::pattern::any_input());
7677

7778
register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
78-
[=](ov::pass::pattern::Matcher &m) {
79+
[OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher &m) {
7980
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
8081
auto root = m.get_match_root();
8182
bool rewritten = false;
@@ -89,7 +90,8 @@ ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs
8990
matmul->set_transpose_a(false);
9091
rewritten |= true;
9192
}
92-
if (matmul->get_transpose_b()) {
93+
94+
if (matmul->get_transpose_b() && !transformation_callback(matmul)) {
9395
extract(matmul->input(1));
9496
matmul->set_transpose_b(false);
9597
rewritten |= true;

0 commit comments

Comments
 (0)