Skip to content

[pass] Update DCE passes #2257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 8 additions & 18 deletions onnxscript/ir/passes/common/unused_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,19 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph


class RemoveUnusedNodesPass(ir.passes.InPlacePass):
"""Pass for removing unused nodes and initializers.
"""Pass for removing unused nodes and initializers (dead code elimination).

Attributes:
remove_initialized_inputs: When an unused initializer is simultaneously a graph input,
remove that input as well. Note that this will change the model input signature.
This pass does not modify the model signature (inputs and outputs). It ensures
that unused nodes and initializers are removed while preserving the original
contract of the model.
"""

def __init__(self, remove_initialized_inputs: bool = False):
super().__init__()
self.remove_initialized_inputs = remove_initialized_inputs

def call(self, model: ir.Model) -> ir.passes.PassResult:
count = _remove_unused_nodes_in_graph_like(model.graph)
graph_outputs = frozenset(model.graph.outputs)
graph_inputs = frozenset(model.graph.inputs)
initializers = model.graph.initializers
if self.remove_initialized_inputs:
graph_inputs = model.graph.inputs
for i, inp in reversed(list(enumerate(graph_inputs))):
if inp.name in initializers and not (inp in graph_outputs or inp.uses()):
del graph_inputs[i]
count += 1
for init in list(initializers.values()):
if not (init in graph_outputs or init.uses()):
if not (init.uses() or init in graph_outputs or init in graph_inputs):
assert init.name is not None
del initializers[init.name]
count += 1
Expand Down Expand Up @@ -193,13 +183,13 @@ def _process_graph_like(

def call(self, model: ir.Model) -> ir.passes.PassResult:
# Record domains of all functions
used_domains = set()
used_domains = {""} # By default always retain the onnx (default) domain
for function in model.functions.values():
used_domains.add(function.domain)
modified = self._process_graph_like(model.graph, used_domains=used_domains)

if self.process_functions:
for function in model.functions.values():
modified |= self._process_graph_like(function, used_domains=set())
modified |= self._process_graph_like(function, used_domains={""})

return ir.passes.PassResult(model, modified=modified)
35 changes: 7 additions & 28 deletions onnxscript/ir/passes/common/unused_removal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
class RemoveUnusedTest(unittest.TestCase):
using_ir: bool

def remove_unused_nodes(
self, model: onnx.ModelProto, remove_initialized_inputs: bool = False
):
def remove_unused_nodes(self, model: onnx.ModelProto):
if self.using_ir:
model_ir = ir.serde.deserialize_model(model)
onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs)
onnxscript.optimizer.remove_unused_nodes(model_ir)
model = ir.serde.serialize_model(model_ir)
return model
onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs)
onnxscript.optimizer.remove_unused_nodes(model)
return model

def test_remove_unused_nodes(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should a test with a model including subgraphs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - will update

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand Down Expand Up @@ -56,24 +54,7 @@ def test_remove_unused_initializers(self):
self.assertEqual(model.graph.node[0].op_type, "Mul")
self.assertEqual(len(model.graph.initializer), 0)

def test_unused_initialized_inputs_are_removed_when_requested(self):
# https://github.com/microsoft/onnxscript/issues/2211
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
agraph (float[N] x, float[N] two) => (float[N] z)
<float two = {2.0,2.0}> {
four = Add(two, two)
z = Mul(x, x)
}
"""
)
model = self.remove_unused_nodes(model, remove_initialized_inputs=True)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "Mul")
self.assertEqual(len(model.graph.input), 1)

def test_unused_initialized_inputs_are_kept_by_default(self):
def test_unused_initialized_inputs_are_kept(self):
model = onnx.parser.parse_model(
"""
<ir_version: 10, opset_import: [ "" : 17]>
Expand All @@ -88,9 +69,9 @@ def test_unused_initialized_inputs_are_kept_by_default(self):
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "Mul")
self.assertEqual(len(model.graph.input), 2)
self.assertEqual(len(model.graph.initializer), 1)

@parameterized.parameterized.expand([True, False])
def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
def test_unused_inputs_are_not_removed(self):
# preserve inputs as part of interface
model = onnx.parser.parse_model(
"""
Expand All @@ -102,9 +83,7 @@ def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
}
"""
)
model = self.remove_unused_nodes(
model, remove_initialized_inputs=remove_initialized_inputs
)
model = self.remove_unused_nodes(model)
self.assertEqual(len(model.graph.node), 1)
self.assertEqual(model.graph.node[0].op_type, "Mul")
self.assertEqual(len(model.graph.input), 2)
Expand Down
14 changes: 5 additions & 9 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,15 @@ def fold_constants(
return result


def remove_unused_nodes(
model: ir.Model | onnx.ModelProto, remove_initialized_inputs: bool = False
) -> None:
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
"""Removes unused nodes from a model inplace."""
if isinstance(model, ir.Model):
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
remove_initialized_inputs=remove_initialized_inputs
)(model)
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model)
else:
model_ir = ir.serde.deserialize_model(model)
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(
remove_initialized_inputs=remove_initialized_inputs
)(model_ir).model
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(
model_ir
).model
new_proto = ir.serde.serialize_model(model_ir)
model.Clear()
model.CopyFrom(new_proto)
Expand Down
Loading