
Commit 87618e8

Update handling of batch-norm in DCE optimization (#1591)
Addresses Issue #1338
Parent commit: 1c154c9

File tree

2 files changed: +50, -4 lines

onnxscript/optimizer/remove_unused.py

Lines changed: 16 additions & 4 deletions
@@ -26,11 +26,23 @@ def remove_unused_optional_outputs(
         op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain)
     except Exception:
         return
-    # TODO: If current node is a BatchNormalization node,
-    # based on training_mode atrribute, number of optional outputs and
-    # how they are handled varies, handle both training_modes
+
     if n.op_type == "BatchNormalization":
-        return
+        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
+        # If running_mean and running_var are not used, remove them, and the training_mode attribute
+        def is_used_output(i: int) -> bool:
+            if i < len(n.output):
+                return n.output[i] in used
+            return False
+
+        if is_used_output(1) or is_used_output(2):
+            return
+        del n.output[1:]
+        for j, attr in enumerate(n.attribute):
+            if attr.name == "training_mode":
+                del n.attribute[j]
+                break
+
     optional_info = []
     for o in op_schema.outputs:
         # Current ops do not have optional outputs if they have variable number of outputs
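
For context, a minimal sketch of the new behavior, mirroring the tests added below. The import paths are an assumption (this diff does not show the test file's imports); the model text is copied from the new test case.

import onnx
import onnx.parser
from onnxscript import optimizer  # assumed import path for the optimizer used in the tests

# A training-mode BatchNormalization whose running-statistics outputs
# (mean_out, var_out) are never consumed by the rest of the graph.
model = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
        z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
    }
    """
)
optimizer.remove_unused_nodes(model)

node = model.graph.node[0]
print(len(node.output))     # 1: mean_out and var_out were dropped
print(len(node.attribute))  # 0: the training_mode attribute was dropped

Note the guard in the new code: if either running_mean or running_var is still consumed, the function returns without touching the node, because node outputs in ONNX are positional and only a trailing run of unused outputs can be deleted (hence del n.output[1:]).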

onnxscript/optimizer/remove_unused_test.py

Lines changed: 34 additions & 0 deletions
@@ -170,6 +170,40 @@ def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
         self.assertEqual(len(model.graph.node[2].output), 3)
 
+    def test_remove_trailing_unused_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) {
+                z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Check that both the mean/var outputs are removed, and training_mode attribute is removed.
+        self.assertEqual(len(model.graph.node[0].output), 1)
+        self.assertEqual(len(model.graph.node[0].attribute), 0)
+
+    def test_avoid_remove_used_optional_outputs_batchnorm(self):
+        model = onnx.parser.parse_model(
+            """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z, float[3] mean_out) {
+                z, mean_out, var_out = BatchNormalization <training_mode=1> (x, scale, B, mean, var)
+            }
+            """
+        )
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+        optimizer.remove_unused_nodes(model)
+        self.assertEqual(len(model.graph.node), 1)
+        self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
+        # Check that the mean/var outputs are NOT removed, and training_mode attribute is NOT removed.
+        self.assertEqual(len(model.graph.node[0].output), 3)
+        self.assertEqual(len(model.graph.node[0].attribute), 1)
+
 
 if __name__ == "__main__":
     unittest.main()
