@@ -170,6 +170,40 @@ def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
170
170
self .assertEqual (model .graph .node [2 ].op_type , "LayerNormalization" )
171
171
self .assertEqual (len (model .graph .node [2 ].output ), 3 )
172
172
173
+ def test_remove_trailing_unused_optional_outputs_batchnorm (self ):
174
+ model = onnx .parser .parse_model (
175
+ """
176
+ <ir_version: 7, opset_import: [ "" : 17]>
177
+ agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
178
+ z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
179
+ }
180
+ """
181
+ )
182
+ self .assertEqual (len (model .graph .node [0 ].attribute ), 1 )
183
+ optimizer .remove_unused_nodes (model )
184
+ self .assertEqual (len (model .graph .node ), 1 )
185
+ self .assertEqual (model .graph .node [0 ].op_type , "BatchNormalization" )
186
+ # Check that both the mean/var outputs are removed, and training_mode attribute is removed.
187
+ self .assertEqual (len (model .graph .node [0 ].output ), 1 )
188
+ self .assertEqual (len (model .graph .node [0 ].attribute ), 0 )
189
+
190
+ def test_avoid_remove_used_optional_outputs_batchnorm (self ):
191
+ model = onnx .parser .parse_model (
192
+ """
193
+ <ir_version: 7, opset_import: [ "" : 17]>
194
+ agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z, float[3] mean_out) {
195
+ z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
196
+ }
197
+ """
198
+ )
199
+ self .assertEqual (len (model .graph .node [0 ].attribute ), 1 )
200
+ optimizer .remove_unused_nodes (model )
201
+ self .assertEqual (len (model .graph .node ), 1 )
202
+ self .assertEqual (model .graph .node [0 ].op_type , "BatchNormalization" )
203
+ # Check that the mean/var outputs are NOT removed, and training_mode attribute is NOT removed.
204
+ self .assertEqual (len (model .graph .node [0 ].output ), 3 )
205
+ self .assertEqual (len (model .graph .node [0 ].attribute ), 1 )
206
+
173
207
174
208
if __name__ == "__main__" :
175
209
unittest .main ()
0 commit comments