1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
4+ #include " core/optimizer/matmul_bn_fusion.h"
5+ #include " core/graph/graph_utils.h"
6+ #include " core/optimizer/initializer.h"
7+ #include " core/optimizer/utils.h"
8+
9+ namespace onnxruntime {
10+
11+ namespace {
12+ const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
13+ {" Reshape" , {1 , 5 , 13 , 14 , 19 }},
14+ {" Transpose" , {1 , 13 }}};
15+ const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {" BatchNormalization" , {1 , 6 , 7 , 9 , 14 , 15 }};
16+ } // namespace
17+
18+ bool NodeIsIgnorable (const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
19+ const Node* curr_node = graph.GetNode (curr_node_index);
20+
21+ // curr_node has different execution provider then it's parent or
22+ // has output edge != 1 (this condition will handle the case when ignorable node
23+ // is graph output i.e. a graph like this "MatMul->Transpose")
24+ if (curr_node->GetExecutionProviderType () != root_node.GetExecutionProviderType () ||
25+ curr_node->GetOutputEdgesCount () != 1 ) {
26+ return false ;
27+ }
28+
29+ // curr_node can be any of the ignorable_nodes.
30+ for (size_t index = 0 ; index < ignorable_nodes.size (); index++) {
31+ if (graph_utils::IsSupportedOptypeVersionAndDomain (*curr_node, ignorable_nodes[index].first , ignorable_nodes[index].second )) {
32+ return true ;
33+ }
34+ }
35+
36+ return false ;
37+ }
38+
39+ std::optional<NodeIndex> MatchPath (const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
40+ while (NodeIsIgnorable (graph, root_node, curr_node_index)) {
41+ curr_node_index = graph.GetNode (curr_node_index)->OutputNodesBegin ()->Index ();
42+ }
43+
44+ // curr_node is neither ignorable nor dest
45+ const Node* curr_node = graph.GetNode (curr_node_index);
46+ if (curr_node->OpType () != dest.first ) {
47+ return std::nullopt ;
48+ }
49+
50+ if (curr_node->GetExecutionProviderType () == root_node.GetExecutionProviderType () &&
51+ graph_utils::IsSupportedOptypeVersionAndDomain (*curr_node, dest.first , dest.second )) {
52+ return curr_node_index;
53+ }
54+
55+ // either curr_node has different execution provider or
56+ // has invalid opset.
57+ return std::nullopt ;
58+ }
59+
/*
 * Given a MatMul node, it will verify the following pattern.
 *        MatMul                       GEMM
 *          |                            |
 *       Reshape ^         --->       Reshape ^
 *          |                            |
 *      Transpose ^                  Transpose ^
 *          |
 *   BatchNormalization
 * Note: ^ means there can be zero or more occurrences of that node.
 * A few examples of fusable patterns:
 * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose
 * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
 * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
 * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape
 * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape
 * - MatMul -> BatchNormalization ---> GEMM
 * Other conditions:
 * - The B tensor of MatMul should be constant.
 * - The scale, B, mean and var tensors of BatchNormalization should be constant.
 * - Every node in the path, except the BatchNormalization, should have only 1 output edge.
 */
82+ bool MatmulBNFusion::SatisfyCondition (const Graph& graph, const Node& node, const logging::Logger&) const {
83+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (node, " MatMul" , {1 , 9 , 13 }) ||
84+ node.GetOutputEdgesCount () != 1 ) {
85+ return false ;
86+ }
87+
88+ if (graph.NodeProducesGraphOutput (node)) {
89+ return false ;
90+ }
91+
92+ // because <node> is not producing graph output, it means it will have a child node
93+ NodeIndex child_node_index = node.OutputNodesBegin ()->Index ();
94+ std::optional<NodeIndex> batch_norm_index = MatchPath (graph, node, child_node_index);
95+ if (!batch_norm_index.has_value ()) {
96+ return false ;
97+ }
98+
99+ const Node* batch_norm_node = graph.GetNode (*batch_norm_index);
100+
101+ // Check that the appropriate inputs to the Matmul and BN nodes are constants.
102+ if (!graph_utils::NodeArgIsConstant (graph, *node.InputDefs ()[1 ]) ||
103+ !graph_utils::NodeArgIsConstant (graph, *batch_norm_node->InputDefs ()[1 ]) ||
104+ !graph_utils::NodeArgIsConstant (graph, *batch_norm_node->InputDefs ()[2 ]) ||
105+ !graph_utils::NodeArgIsConstant (graph, *batch_norm_node->InputDefs ()[3 ]) ||
106+ !graph_utils::NodeArgIsConstant (graph, *batch_norm_node->InputDefs ()[4 ])) {
107+ return false ;
108+ }
109+
110+ // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
111+ const auto & output_defs = batch_norm_node->OutputDefs ();
112+ if (output_defs.size () > 1 ) {
113+ for (size_t i = 1 , end = output_defs.size (); i < end; ++i) {
114+ if (output_defs[i] != nullptr && output_defs[i]->Exists ()) {
115+ return false ;
116+ }
117+ }
118+ }
119+
120+ return true ;
121+ }
122+
/*
 * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc]
 *   Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML
 * Expanding out the terms:
 *   Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
 * Here,
 *   [Scale / sqrt(Variance + Epsilon)] is constant, and let's call it `alpha`
 *   [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta`
 *   Output = alpha * Input + beta, where Input is the MatMul result (A x B).
 *   Because alpha is applied per output column, it can be folded into the B tensor of MatMul,
 *   and beta becomes the bias (C) input of the resulting Gemm.
 */
134+ Status MatmulBNFusion::Apply (Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
135+ NodeIndex child_node_index = matmul_node.OutputNodesBegin ()->Index ();
136+ NodeIndex batch_norm_node_index = MatchPath (graph, matmul_node, child_node_index).value ();
137+
138+ Node& batch_norm_node = *graph.GetNode (batch_norm_node_index); // need mutable node, that's why extracting node from graph
139+
140+ // only perform fusion if epsilon is present and is of float_32 type
141+ auto epsilon_attribute = batch_norm_node.GetAttributes ().find (" epsilon" );
142+ if (epsilon_attribute == batch_norm_node.GetAttributes ().end () ||
143+ epsilon_attribute->second .type () != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) {
144+ return Status::OK ();
145+ }
146+ const float epsilon = epsilon_attribute->second .f ();
147+
148+ const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer (graph, batch_norm_node.InputDefs ()[1 ]->Name ());
149+ ORT_ENFORCE (scale_tensor);
150+ const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer (graph, batch_norm_node.InputDefs ()[2 ]->Name ());
151+ ORT_ENFORCE (bias_tensor);
152+ const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer (graph, batch_norm_node.InputDefs ()[3 ]->Name ());
153+ ORT_ENFORCE (mean_tensor);
154+ const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer (graph, batch_norm_node.InputDefs ()[4 ]->Name ());
155+ ORT_ENFORCE (var_tensor);
156+ const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer (graph, matmul_node.InputDefs ()[1 ]->Name ());
157+ ORT_ENFORCE (matmul_b_tensor);
158+
159+ if (!optimizer_utils::IsFloatingPointDataType (*matmul_b_tensor) ||
160+ !optimizer_utils::IsFloatingPointDataType (*scale_tensor) ||
161+ !optimizer_utils::IsFloatingPointDataType (*bias_tensor) ||
162+ !optimizer_utils::IsFloatingPointDataType (*mean_tensor) ||
163+ !optimizer_utils::IsFloatingPointDataType (*var_tensor) ||
164+ scale_tensor->dims_size () != 1 ||
165+ bias_tensor->dims_size () != 1 ||
166+ mean_tensor->dims_size () != 1 ||
167+ var_tensor->dims_size () != 1 ||
168+ scale_tensor->dims (0 ) != matmul_b_tensor->dims (1 ) ||
169+ bias_tensor->dims (0 ) != matmul_b_tensor->dims (1 ) ||
170+ mean_tensor->dims (0 ) != matmul_b_tensor->dims (1 ) ||
171+ var_tensor->dims (0 ) != matmul_b_tensor->dims (1 )) {
172+ return Status::OK ();
173+ }
174+
175+ /*
176+ * temp = scale / sqrt(var + epsilon)
177+ * output = (temp * Input) - ((temp * mean) + bias)
178+ */
179+ Initializer scale (*scale_tensor, graph.ModelPath ());
180+ Initializer bias (*bias_tensor, graph.ModelPath ());
181+ Initializer mean (*mean_tensor, graph.ModelPath ());
182+ Initializer var (*var_tensor, graph.ModelPath ());
183+ Initializer matmul_b (*matmul_b_tensor, graph.ModelPath ());
184+
185+ var.add (epsilon);
186+ var.sqrt ();
187+ scale.div (var); // this is the temp
188+ matmul_b.scale_by_axis (scale, 1 , true );
189+
190+ mean.mul (scale);
191+ bias.sub (mean);
192+
193+ // create B tensorProto for new Gemm node from <matmulB> initializer.
194+ ONNX_NAMESPACE::TensorProto new_gemm_b_tensor (*matmul_b_tensor);
195+ matmul_b.ToProto (new_gemm_b_tensor);
196+ const std::string new_gemm_b_name = graph.GenerateNodeArgName (" MatMulBnFusion_GemmB_" + matmul_b_tensor->name ());
197+ new_gemm_b_tensor.set_name (new_gemm_b_name);
198+ NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer (graph, new_gemm_b_tensor);
199+
200+ // create bias tensorProto for new Gemm node from <bias> initializer.
201+ ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor (*bias_tensor);
202+ bias.ToProto (new_gemm_bias_tensor);
203+ const std::string new_gemm_bias_name = graph.GenerateNodeArgName (" MatMulBnFusion_GemmBias" );
204+ new_gemm_bias_tensor.set_name (new_gemm_bias_name);
205+ NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer (graph, new_gemm_bias_tensor);
206+
207+ Node& gemm_node = graph.AddNode (
208+ graph.GenerateNodeArgName (" MatMulBnFusion_Gemm" ),
209+ " Gemm" ,
210+ " Generated from Matmul BatchNormalization fusion" ,
211+ {matmul_node.MutableInputDefs ()[0 ], &new_gemm_b_node_arg, &new_gemm_bias_node_arg},
212+ matmul_node.MutableOutputDefs (),
213+ nullptr ,
214+ kOnnxDomain );
215+
216+ // Remove MatMul node.
217+ Node* node = graph.GetNode (matmul_node.Index ());
218+ graph_utils::RemoveNodeOutputEdges (graph, *node);
219+ graph.RemoveNode (matmul_node.Index ());
220+
221+ // Delete optional empty output defs.
222+ // Delete BatchNormalization node and update the input of the child of BatchNormalization
223+ batch_norm_node.MutableOutputDefs ().resize (1 );
224+ NodeIndex batch_norm_parent_index = graph.GetNode (child_node_index)->OpType () == " BatchNormalization" ? gemm_node.Index () : batch_norm_node.InputNodesBegin ()->Index ();
225+ graph_utils::FinalizeNodeFusion (graph, *graph.GetNode (batch_norm_parent_index), batch_norm_node);
226+
227+ rule_effect = RewriteRuleEffect::kRemovedCurrentNode ;
228+ return Status::OK ();
229+ }
230+ } // namespace onnxruntime
0 commit comments