Skip to content

[pass] Fix DCE to keep initializers that are inputs #2245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
29 changes: 22 additions & 7 deletions onnxscript/ir/passes/common/unused_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,25 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph
return count


def _maybe_remove_unused_initialized_inputs(model: ir.Model, do_remove: bool) -> None:
graph_outputs = model.graph.outputs
initializers = model.graph.initializers
graph_inputs = model.graph.inputs
unused_init_inputs = []
for i, inp in reversed(list(enumerate(graph_inputs))):
if inp.name in initializers and not (inp.uses() or inp in graph_outputs):
if do_remove:
del graph_inputs[i]
else:
unused_init_inputs.append(inp.name)
if unused_init_inputs:
logger.warning(
"RemoveUnusedNodesPass: Found unused initialized inputs %s,"
" consider turning `remove_initialized_inputs` on",
unused_init_inputs,
)


class RemoveUnusedNodesPass(ir.passes.InPlacePass):
"""Pass for removing unused nodes and initializers.

Expand All @@ -108,14 +127,10 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
count = _remove_unused_nodes_in_graph_like(model.graph)
graph_outputs = frozenset(model.graph.outputs)
initializers = model.graph.initializers
if self.remove_initialized_inputs:
graph_inputs = model.graph.inputs
for i, inp in reversed(list(enumerate(graph_inputs))):
if inp.name in initializers and not (inp in graph_outputs or inp.uses()):
del graph_inputs[i]
count += 1
graph_inputs = model.graph.inputs
_maybe_remove_unused_initialized_inputs(model, self.remove_initialized_inputs)
for init in list(initializers.values()):
if not (init in graph_outputs or init.uses()):
if not (init.uses() or init in graph_outputs or init in graph_inputs):
assert init.name is not None
del initializers[init.name]
count += 1
Expand Down
1 change: 1 addition & 0 deletions onnxscript/ir/passes/common/unused_removal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_unused_initialized_inputs_are_kept_by_default(self):
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "Mul")
self.assertEqual(len(model.graph.input), 2)
self.assertEqual(len(model.graph.initializer), 1)

@parameterized.parameterized.expand([True, False])
def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
Expand Down
Loading