Skip to content

Commit f407d47

Browse files
justinchubyCopilot
andauthored
[pass] Update DCE passes (#2257)
- Remove the `remove_initialized_inputs` option in dce because the contract of the pass it that it does not modify model signature. Fixed bugs where initializers are removed. Instead, users can use #2253 to remove the initialized inputs first. - Additionally updated RemoveUnusedOpsetsPass to always retain the default opset. --------- Co-authored-by: Copilot <[email protected]>
1 parent a8f56c2 commit f407d47

File tree

3 files changed

+20
-54
lines changed

3 files changed

+20
-54
lines changed

onnxscript/ir/passes/common/unused_removal.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,29 +93,20 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph
9393

9494

9595
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
96-
"""Pass for removing unused nodes and initializers.
96+
"""Pass for removing unused nodes and initializers (dead code elimination).
9797
98-
Attributes:
99-
remove_initialized_inputs: When an unused initializer is simultaneously a graph input,
100-
remove that input as well. Note that this will change the model input signature.
98+
This pass does not modify the model signature (inputs and outputs). It ensures
99+
that unused nodes and initializers are removed while preserving the original
100+
contract of the model.
101101
"""
102102

103-
def __init__(self, remove_initialized_inputs: bool = False):
104-
super().__init__()
105-
self.remove_initialized_inputs = remove_initialized_inputs
106-
107103
def call(self, model: ir.Model) -> ir.passes.PassResult:
108104
count = _remove_unused_nodes_in_graph_like(model.graph)
109105
graph_outputs = frozenset(model.graph.outputs)
106+
graph_inputs = frozenset(model.graph.inputs)
110107
initializers = model.graph.initializers
111-
if self.remove_initialized_inputs:
112-
graph_inputs = model.graph.inputs
113-
for i, inp in reversed(list(enumerate(graph_inputs))):
114-
if inp.name in initializers and not (inp in graph_outputs or inp.uses()):
115-
del graph_inputs[i]
116-
count += 1
117108
for init in list(initializers.values()):
118-
if not (init in graph_outputs or init.uses()):
109+
if not (init.uses() or init in graph_outputs or init in graph_inputs):
119110
assert init.name is not None
120111
del initializers[init.name]
121112
count += 1
@@ -193,13 +184,13 @@ def _process_graph_like(
193184

194185
def call(self, model: ir.Model) -> ir.passes.PassResult:
195186
# Record domains of all functions
196-
used_domains = set()
187+
used_domains = {""} # By default always retain the onnx (default) domain
197188
for function in model.functions.values():
198189
used_domains.add(function.domain)
199190
modified = self._process_graph_like(model.graph, used_domains=used_domains)
200191

201192
if self.process_functions:
202193
for function in model.functions.values():
203-
modified |= self._process_graph_like(function, used_domains=set())
194+
modified |= self._process_graph_like(function, used_domains={""})
204195

205196
return ir.passes.PassResult(model, modified=modified)

onnxscript/ir/passes/common/unused_removal_test.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,13 @@
1313
class RemoveUnusedTest(unittest.TestCase):
1414
using_ir: bool
1515

16-
def remove_unused_nodes(
17-
self, model: onnx.ModelProto, remove_initialized_inputs: bool = False
18-
):
16+
def remove_unused_nodes(self, model: onnx.ModelProto):
1917
if self.using_ir:
2018
model_ir = ir.serde.deserialize_model(model)
21-
onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs)
19+
onnxscript.optimizer.remove_unused_nodes(model_ir)
2220
model = ir.serde.serialize_model(model_ir)
2321
return model
24-
onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs)
22+
onnxscript.optimizer.remove_unused_nodes(model)
2523
return model
2624

2725
def test_remove_unused_nodes(self):
@@ -56,24 +54,7 @@ def test_remove_unused_initializers(self):
5654
self.assertEqual(model.graph.node[0].op_type, "Mul")
5755
self.assertEqual(len(model.graph.initializer), 0)
5856

59-
def test_unused_initialized_inputs_are_removed_when_requested(self):
60-
# https://github.com/microsoft/onnxscript/issues/2211
61-
model = onnx.parser.parse_model(
62-
"""
63-
<ir_version: 10, opset_import: [ "" : 17]>
64-
agraph (float[N] x, float[N] two) => (float[N] z)
65-
<float two = {2.0,2.0}> {
66-
four = Add(two, two)
67-
z = Mul(x, x)
68-
}
69-
"""
70-
)
71-
model = self.remove_unused_nodes(model, remove_initialized_inputs=True)
72-
self.assertEqual(len(model.graph.node), 1)
73-
self.assertEqual(model.graph.node[0].op_type, "Mul")
74-
self.assertEqual(len(model.graph.input), 1)
75-
76-
def test_unused_initialized_inputs_are_kept_by_default(self):
57+
def test_unused_initialized_inputs_are_kept(self):
7758
model = onnx.parser.parse_model(
7859
"""
7960
<ir_version: 10, opset_import: [ "" : 17]>
@@ -88,9 +69,9 @@ def test_unused_initialized_inputs_are_kept_by_default(self):
8869
self.assertEqual(len(model.graph.node), 1)
8970
self.assertEqual(model.graph.node[0].op_type, "Mul")
9071
self.assertEqual(len(model.graph.input), 2)
72+
self.assertEqual(len(model.graph.initializer), 1)
9173

92-
@parameterized.parameterized.expand([True, False])
93-
def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
74+
def test_unused_inputs_are_not_removed(self):
9475
# preserve inputs as part of interface
9576
model = onnx.parser.parse_model(
9677
"""
@@ -102,9 +83,7 @@ def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
10283
}
10384
"""
10485
)
105-
model = self.remove_unused_nodes(
106-
model, remove_initialized_inputs=remove_initialized_inputs
107-
)
86+
model = self.remove_unused_nodes(model)
10887
self.assertEqual(len(model.graph.node), 1)
10988
self.assertEqual(model.graph.node[0].op_type, "Mul")
11089
self.assertEqual(len(model.graph.input), 2)

onnxscript/optimizer/__init__.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,15 @@ def fold_constants(
112112
return result
113113

114114

115-
def remove_unused_nodes(
116-
model: ir.Model | onnx.ModelProto, remove_initialized_inputs: bool = False
117-
) -> None:
115+
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
118116
"""Removes unused nodes from a model inplace."""
119117
if isinstance(model, ir.Model):
120-
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
121-
remove_initialized_inputs=remove_initialized_inputs
122-
)(model)
118+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model)
123119
else:
124120
model_ir = ir.serde.deserialize_model(model)
125-
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
126-
remove_initialized_inputs=remove_initialized_inputs
127-
)(model_ir).model
121+
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(
122+
model_ir
123+
).model
128124
new_proto = ir.serde.serialize_model(model_ir)
129125
model.Clear()
130126
model.CopyFrom(new_proto)

0 commit comments

Comments
 (0)