Skip to content

Commit ae85619

Browse files
authored
Introduce new optimizer MatMul + BatchNormalization (#17915)
### Description Introduce new ORT L1 optimizer under RewriteRule category to fuse MatMul + BatchNormalization node. This optimizer look for a specific pattern observed in one of the impacting customer models and fuse the Matmul and Batchnormalization node into a Gemm node. For details on the pattern matching and fusion please refer to the comment section of `matmul_bn_fusion.cc`. To visualize, this optimizer will replace following subgraph to a Gemm node. <pre> MatMul GEMM | | Reshape ^ ---> Reshape ^ | | Transpose ^ Transpose ^ | BatchNormalization Note: ^ means there can be >=0 occurrence(s) of that node. Few example fusable pattern: * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape * - MatMul -> BatchNormalization ---> GEMM </pre> Note: This optimizer may evolve in the future to be more generic in terms of the pattern matching. ### Motivation and Context - Why is this change required? What problem does it solve? One of the user of ORT+DML ep needs this to better target the model to DML. But this transformation applies more broadly, so added L1 optimizer. <!-- - If it fixes an open issue, please link to the issue here. -->
1 parent 76e275b commit ae85619

11 files changed

Lines changed: 543 additions & 9 deletions

onnxruntime/core/optimizer/graph_transformer_utils.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include "core/optimizer/matmul_integer_to_float.h"
5151
#include "core/optimizer/matmul_scale_fusion.h"
5252
#include "core/optimizer/matmul_transpose_fusion.h"
53+
#include "core/optimizer/matmul_bn_fusion.h"
5354
#include "core/optimizer/nchwc_transformer.h"
5455
#include "core/optimizer/noop_elimination.h"
5556
#include "core/optimizer/not_where_fusion.h"
@@ -127,6 +128,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
127128
rules.push_back(std::make_unique<ConvAddFusion>());
128129
rules.push_back(std::make_unique<ConvMulFusion>());
129130
rules.push_back(std::make_unique<ConvBNFusion>());
131+
rules.push_back(std::make_unique<MatmulBNFusion>());
130132
rules.push_back(std::make_unique<ClipQuantFusion>());
131133
rules.push_back(std::make_unique<ReluQuantFusion>());
132134
break;

onnxruntime/core/optimizer/initializer.cc

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() {
291291
namespace {
292292
template <typename T>
293293
struct ScaleByAxis {
294-
void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const {
294+
void operator()(Tensor& data,
295+
const Tensor& scalers,
296+
const size_t block_size,
297+
const size_t num_blocks,
298+
const bool column_major) const {
295299
ToNumeric<T> to_numeric;
296300
const auto scaler_size = scalers.Shape().Size();
297301
T* dst = data.MutableData<T>();
@@ -303,24 +307,32 @@ struct ScaleByAxis {
303307
}
304308
} else {
305309
for (size_t block_offset = 0, i = 0; i < num_blocks; i++) {
306-
const auto numeric_scaler = to_numeric(scalers_data[i]);
307-
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
308-
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
310+
if (column_major) {
311+
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
312+
const auto numeric_scaler = to_numeric(scalers_data[j]);
313+
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
314+
}
315+
} else {
316+
const auto numeric_scaler = to_numeric(scalers_data[i]);
317+
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
318+
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
319+
}
309320
}
310321
}
311322
}
312323
}
313324
};
314-
315325
} // namespace
316326

317-
void Initializer::scale_by_axis(const Initializer& scalers, int axis) {
327+
void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) {
318328
ORT_ENFORCE(axis >= 0, "Axis must be non-negative");
319329
const size_t block_size = narrow<size_t>(data_.Shape().SizeFromDimension(gsl::narrow_cast<size_t>(axis)));
320330
const size_t num_blocks = size() / block_size;
321-
ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size");
331+
ORT_ENFORCE(scalers.size() == 1 ||
332+
(column_major ? scalers.size() == block_size : scalers.size() == num_blocks),
333+
"Invalid other(scalers) size");
322334
utils::MLTypeCallDispatcher<MLFloat16, BFloat16, float, double, int32_t, int64_t> t_disp(data_.GetElementType());
323-
t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks);
335+
t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks, column_major);
324336
}
325337
#endif // ORT_EXTENDED_MINIMAL_BUILD
326338
} // namespace onnxruntime

onnxruntime/core/optimizer/initializer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Initializer final {
8686

8787
Initializer& sqrt();
8888

89-
void scale_by_axis(const Initializer& other, int axis);
89+
void scale_by_axis(const Initializer& other, int axis, bool column_major = false);
9090
#endif // ORT_EXTENDED_MINIMAL_BUILD
9191
private:
9292
std::string name_;
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/optimizer/matmul_bn_fusion.h"
5+
#include "core/graph/graph_utils.h"
6+
#include "core/optimizer/initializer.h"
7+
#include "core/optimizer/utils.h"
8+
9+
namespace onnxruntime {
10+
11+
namespace {
12+
const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
13+
{"Reshape", {1, 5, 13, 14, 19}},
14+
{"Transpose", {1, 13}}};
15+
const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
16+
} // namespace
17+
18+
bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
19+
const Node* curr_node = graph.GetNode(curr_node_index);
20+
21+
// curr_node has different execution provider then it's parent or
22+
// has output edge != 1 (this condition will handle the case when ignorable node
23+
// is graph output i.e. a graph like this "MatMul->Transpose")
24+
if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
25+
curr_node->GetOutputEdgesCount() != 1) {
26+
return false;
27+
}
28+
29+
// curr_node can be any of the ignorable_nodes.
30+
for (size_t index = 0; index < ignorable_nodes.size(); index++) {
31+
if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
32+
return true;
33+
}
34+
}
35+
36+
return false;
37+
}
38+
39+
std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
40+
while (NodeIsIgnorable(graph, root_node, curr_node_index)) {
41+
curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index();
42+
}
43+
44+
// curr_node is neither ignorable nor dest
45+
const Node* curr_node = graph.GetNode(curr_node_index);
46+
if (curr_node->OpType() != dest.first) {
47+
return std::nullopt;
48+
}
49+
50+
if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
51+
graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) {
52+
return curr_node_index;
53+
}
54+
55+
// either curr_node has different execution provider or
56+
// has invalid opset.
57+
return std::nullopt;
58+
}
59+
60+
/*
61+
* Given a MatMul node, it will verify the following pattern.
62+
* MatMul GEMM
63+
* | |
64+
* Reshape ^ ---> Reshape ^
65+
* | |
66+
* Transpose ^ Transpose ^
67+
* |
68+
* BatchNormalization
69+
* Note: ^ means there can be 0 or any occurrences of that node.
70+
* Few example fusable pattern:
71+
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose
72+
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
73+
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
74+
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape
75+
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape
76+
* - MatMul -> BatchNormalization ---> GEMM
77+
* Other Conditions:
78+
* - B tensor of MatMul should be constant.
79+
* - scale, B, mean, var tensors of BatchNormalization should be constant.
80+
* - Every node in the path, except the BatchNormalization, should have only 1 output edge.
81+
*/
82+
bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
83+
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
84+
node.GetOutputEdgesCount() != 1) {
85+
return false;
86+
}
87+
88+
if (graph.NodeProducesGraphOutput(node)) {
89+
return false;
90+
}
91+
92+
// because <node> is not producing graph output, it means it will have a child node
93+
NodeIndex child_node_index = node.OutputNodesBegin()->Index();
94+
std::optional<NodeIndex> batch_norm_index = MatchPath(graph, node, child_node_index);
95+
if (!batch_norm_index.has_value()) {
96+
return false;
97+
}
98+
99+
const Node* batch_norm_node = graph.GetNode(*batch_norm_index);
100+
101+
// Check that the appropriate inputs to the Matmul and BN nodes are constants.
102+
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
103+
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) ||
104+
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) ||
105+
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) ||
106+
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) {
107+
return false;
108+
}
109+
110+
// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
111+
const auto& output_defs = batch_norm_node->OutputDefs();
112+
if (output_defs.size() > 1) {
113+
for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
114+
if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
115+
return false;
116+
}
117+
}
118+
}
119+
120+
return true;
121+
}
122+
123+
/*
124+
* BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc]
125+
* Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML
126+
* Expanding out the terms:
127+
* Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
128+
* Here,
129+
* [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha`
130+
* [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta`
131+
* Output = alpha * Input + beta, Input = B tensor of MatMul.
132+
*
133+
*/
134+
Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
135+
NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index();
136+
NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value();
137+
138+
Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph
139+
140+
// only perform fusion if epsilon is present and is of float_32 type
141+
auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
142+
if (epsilon_attribute == batch_norm_node.GetAttributes().end() ||
143+
epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) {
144+
return Status::OK();
145+
}
146+
const float epsilon = epsilon_attribute->second.f();
147+
148+
const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name());
149+
ORT_ENFORCE(scale_tensor);
150+
const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name());
151+
ORT_ENFORCE(bias_tensor);
152+
const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name());
153+
ORT_ENFORCE(mean_tensor);
154+
const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name());
155+
ORT_ENFORCE(var_tensor);
156+
const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name());
157+
ORT_ENFORCE(matmul_b_tensor);
158+
159+
if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) ||
160+
!optimizer_utils::IsFloatingPointDataType(*scale_tensor) ||
161+
!optimizer_utils::IsFloatingPointDataType(*bias_tensor) ||
162+
!optimizer_utils::IsFloatingPointDataType(*mean_tensor) ||
163+
!optimizer_utils::IsFloatingPointDataType(*var_tensor) ||
164+
scale_tensor->dims_size() != 1 ||
165+
bias_tensor->dims_size() != 1 ||
166+
mean_tensor->dims_size() != 1 ||
167+
var_tensor->dims_size() != 1 ||
168+
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
169+
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
170+
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
171+
var_tensor->dims(0) != matmul_b_tensor->dims(1)) {
172+
return Status::OK();
173+
}
174+
175+
/*
176+
* temp = scale / sqrt(var + epsilon)
177+
* output = (temp * Input) - ((temp * mean) + bias)
178+
*/
179+
Initializer scale(*scale_tensor, graph.ModelPath());
180+
Initializer bias(*bias_tensor, graph.ModelPath());
181+
Initializer mean(*mean_tensor, graph.ModelPath());
182+
Initializer var(*var_tensor, graph.ModelPath());
183+
Initializer matmul_b(*matmul_b_tensor, graph.ModelPath());
184+
185+
var.add(epsilon);
186+
var.sqrt();
187+
scale.div(var); // this is the temp
188+
matmul_b.scale_by_axis(scale, 1, true);
189+
190+
mean.mul(scale);
191+
bias.sub(mean);
192+
193+
// create B tensorProto for new Gemm node from <matmulB> initializer.
194+
ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor);
195+
matmul_b.ToProto(new_gemm_b_tensor);
196+
const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name());
197+
new_gemm_b_tensor.set_name(new_gemm_b_name);
198+
NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor);
199+
200+
// create bias tensorProto for new Gemm node from <bias> initializer.
201+
ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor);
202+
bias.ToProto(new_gemm_bias_tensor);
203+
const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias");
204+
new_gemm_bias_tensor.set_name(new_gemm_bias_name);
205+
NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor);
206+
207+
Node& gemm_node = graph.AddNode(
208+
graph.GenerateNodeArgName("MatMulBnFusion_Gemm"),
209+
"Gemm",
210+
"Generated from Matmul BatchNormalization fusion",
211+
{matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg},
212+
matmul_node.MutableOutputDefs(),
213+
nullptr,
214+
kOnnxDomain);
215+
216+
// Remove MatMul node.
217+
Node* node = graph.GetNode(matmul_node.Index());
218+
graph_utils::RemoveNodeOutputEdges(graph, *node);
219+
graph.RemoveNode(matmul_node.Index());
220+
221+
// Delete optional empty output defs.
222+
// Delete BatchNormalization node and update the input of the child of BatchNormalization
223+
batch_norm_node.MutableOutputDefs().resize(1);
224+
NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
225+
graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
226+
227+
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
228+
return Status::OK();
229+
}
230+
} // namespace onnxruntime
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/optimizer/rewrite_rule.h"
7+
8+
namespace onnxruntime {
9+
/*
10+
* This fusion submerges a BatchNormalization operator to it's super
11+
* precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition()
12+
* is true.
13+
*/
14+
class MatmulBNFusion : public RewriteRule {
15+
public:
16+
MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {}
17+
18+
std::vector<std::string> TargetOpTypes() const noexcept override {
19+
return {"MatMul"};
20+
}
21+
22+
private:
23+
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
24+
25+
Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
26+
};
27+
} // namespace onnxruntime

0 commit comments

Comments
 (0)