Skip to content

[Passes] Consolidate DCE passes into common passes #2143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
179 changes: 179 additions & 0 deletions onnxscript/ir/passes/common/unused_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

__all__ = [
"RemoveUnusedNodesPass",
"RemoveUnusedFunctionsPass",
"RemoveUnusedOpsetsPass",
]

import logging

import onnx

from onnxscript import ir

logger = logging.getLogger(__name__)


def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Clear trailing optional outputs of ``node`` that are neither graph outputs nor consumed.

    An output is "removed" by setting its name to the empty string (serialized as an
    empty output slot). Only outputs the ONNX op schema declares Optional are eligible;
    ops with variadic outputs are left untouched. BatchNormalization gets special
    handling: its running_mean/running_var outputs and the training_mode attribute are
    dropped together when both outputs are unused.
    """
    try:
        # Only standard ONNX ops have schemas we can query here.
        # NOTE(review): the original checked "onnx.ai"; the standard domain string is
        # "ai.onnx" (the empty string is its alias). Either way, unknown domains fall
        # through to the except below.
        if node.domain not in {"", "ai.onnx"}:
            return
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:
        # No schema for this op/version: nothing we can safely do.
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var.
        # If running_mean and running_var are not used, remove them, and the
        # training_mode attribute (only meaningful alongside those outputs).
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If no optional outputs in spec, skip delete operations.
    # (Fixed: the previous `len([o == 1 for o in optional_info]) == 0` built a
    # same-length list of booleans, so it only detected an *empty* output spec,
    # never "outputs exist but none are optional".)
    if not any(optional_info):
        return

    for i, out in enumerate(node.outputs):
        if out not in graph_outputs and (not out.uses()) and optional_info[i] is True:
            out.name = ""


def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove dead nodes from a graph or function (recursing into subgraph attributes).

    Returns the total number of nodes removed, including those removed from nested
    GRAPH/GRAPHS attributes (e.g. If/Loop bodies).
    """
    graph_outputs = frozenset(function_or_graph.outputs)
    onnx_opset_version = function_or_graph.opset_imports.get("", None)
    removed = 0
    # Walk backwards so that removing a consumer can make its producers dead
    # within the same sweep.
    for node in reversed(function_or_graph):
        is_dead = all(
            output not in graph_outputs and not output.uses() for output in node.outputs
        )
        if is_dead:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue
        # Live node: trim unused optional outputs (when the default opset is
        # known), then recurse into any subgraphs it carries.
        if onnx_opset_version is not None:
            _remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                removed += _remove_unused_nodes_in_graph_like(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
                for subgraph in attr.as_graphs():
                    removed += _remove_unused_nodes_in_graph_like(subgraph)
    return removed


class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Remove dead nodes and unused initializers from the main graph and all functions."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        removed = _remove_unused_nodes_in_graph_like(model.graph)

        # Drop initializers that are neither graph outputs nor consumed by any node.
        graph_outputs = frozenset(model.graph.outputs)
        initializers = model.graph.initializers
        for initializer in tuple(initializers.values()):
            if initializer in graph_outputs or initializer.uses():
                continue
            assert initializer.name is not None
            del initializers[initializer.name]
            removed += 1

        for function in model.functions.values():
            removed += _remove_unused_nodes_in_graph_like(function)

        if not removed:
            return ir.passes.PassResult(model, modified=False)
        logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=True)


class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Delete model-local functions that are unreachable from the main graph."""

    def __init__(self):
        super().__init__()
        # Identifiers of functions found reachable during call(); None between runs.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        self._used = set()
        # Mark every function transitively reachable from the main graph.
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            self._call_node(model, node)

        # Whatever was never marked is unreachable and can be deleted.
        unused = set(model.functions) - self._used
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        for op_identifier in unused:
            del model.functions[op_identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)

        self._used = None
        return ir.passes.PassResult(model, modified=True)

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` used and recurse into the functions it calls."""
        assert self._used is not None
        identifier = function.identifier()
        if identifier in self._used:
            # Already visited; its callees are recorded too.
            return
        self._used.add(identifier)
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """If ``node`` invokes a model-local function, mark that function used."""
        op_identifier = node.op_identifier()
        if op_identifier in model.functions:
            self._call_function(model, model.functions[op_identifier])


class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    Attributes:
        process_functions: Whether to process functions in the model. If True, the pass will
            remove unused opset imports from functions as well. If False, only the main graph
            will be processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(self, graph_like: ir.Graph | ir.Function) -> bool:
        """Drop opset imports for domains no node (including subgraph nodes) references."""
        referenced_domains = {
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        }
        stale = [d for d in graph_like.opset_imports if d not in referenced_domains]
        for domain in stale:
            del graph_like.opset_imports[domain]
        return bool(stale)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        modified = self._process_graph_like(model.graph)

        if self.process_functions:
            for function in model.functions.values():
                modified = self._process_graph_like(function) or modified

        return ir.passes.PassResult(model, modified=modified)
26 changes: 25 additions & 1 deletion onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"inline",
]

import onnx

import onnxscript.ir.passes.common.unused_removal

import onnxscript.optimizer._constant_folding as constant_folding
Expand All @@ -20,7 +21,6 @@
from onnxscript import ir
from onnxscript.optimizer._inliner import inline
from onnxscript.optimizer._optimizer import optimize_ir
from onnxscript.optimizer._remove_unused import remove_unused_nodes

basic_constant_propagation = constant_folding.basic_constant_propagation
fold_constants_ir = constant_folding.fold_constants
Expand All @@ -40,3 +40,27 @@ def fold_constants(model: ir.Model | onnx.ModelProto, *args, **kwargs) -> bool:
return constant_folding.fold_constants(model, *args, **kwargs)
else:
return legacy_constant_folding.fold_constants(model, *args, **kwargs)


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused nodes from a model inplace.

    Accepts either an ``ir.Model`` (mutated directly) or an ``onnx.ModelProto``
    (round-tripped through the IR, then copied back into the same proto object).
    """
    # Import the submodule by its real package path: the module lives under
    # `onnxscript.ir`, so a bare top-level `import ir...` would fail.
    from onnxscript.ir.passes.common import unused_removal

    if isinstance(model, ir.Model):
        unused_removal.RemoveUnusedNodesPass()(model)
    else:
        model_ir = ir.serde.deserialize_model(model)
        model_ir = unused_removal.RemoveUnusedNodesPass()(model_ir).model
        new_proto = ir.serde.serialize_model(model_ir)
        model.Clear()
        model.CopyFrom(new_proto)


def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
    """Removes unused functions from a model inplace.

    Accepts either an ``ir.Model`` (mutated directly) or an ``onnx.ModelProto``
    (round-tripped through the IR, then copied back into the same proto object).
    """
    # Import the submodule by its real package path: the module lives under
    # `onnxscript.ir`, so a bare top-level `import ir...` would fail.
    from onnxscript.ir.passes.common import unused_removal

    if isinstance(model, ir.Model):
        unused_removal.RemoveUnusedFunctionsPass()(model)
    else:
        model_ir = ir.serde.deserialize_model(model)
        model_ir = unused_removal.RemoveUnusedFunctionsPass()(model_ir).model
        new_proto = ir.serde.serialize_model(model_ir)
        model.Clear()
        model.CopyFrom(new_proto)
7 changes: 3 additions & 4 deletions onnxscript/optimizer/_legacy/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
import onnx
import onnx.shape_inference

import onnxscript.optimizer
from onnxscript import rewriter
from onnxscript.optimizer._legacy._simple_function_folding import (
inline_functions_with_unused_outputs,
inline_simple_functions,
)
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
from onnxscript.optimizer._remove_unused import remove_unused_nodes
from onnxscript.optimizer._remove_unused_function import remove_unused_functions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,9 +70,9 @@ def optimize(
model, external_data_folder, onnx_shape_inference=onnx_shape_inference
)

remove_unused_nodes(model)
onnxscript.optimizer.remove_unused_nodes(model)
inline_simple_functions(model)
model = remove_unused_functions(model)
onnxscript.optimizer.remove_unused_functions(model)
inline_functions_with_unused_outputs(model)
# NOTE: This is general rewrite rules
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
Expand Down
21 changes: 14 additions & 7 deletions onnxscript/optimizer/_legacy/_simple_function_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@

import onnx

from onnxscript.optimizer import _remove_unused_function
from onnxscript import ir
from onnxscript.ir.passes.common import unused_removal
from onnxscript.optimizer._legacy import _simple_function_folding


def _remove_unused_functions(model_proto: onnx.ModelProto) -> onnx.ModelProto:
    """Run RemoveUnusedFunctionsPass over a ModelProto via an IR round trip."""
    ir_model = ir.serde.deserialize_model(model_proto)
    result = unused_removal.RemoveUnusedFunctionsPass()(ir_model)
    return ir.serde.serialize_model(result.model)


class SingleNodeFunctionFoldingTest(unittest.TestCase):
def test_fold_single_node_function(self):
model = onnx.parser.parse_model(
Expand All @@ -34,7 +41,7 @@ def test_fold_single_node_function(self):
)

_simple_function_folding.inline_simple_functions(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)

self.assertEqual(len(model.functions), 0)

Expand All @@ -61,7 +68,7 @@ def test_fold_single_node_function_ref_attr(self):
)

_simple_function_folding.inline_simple_functions(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)

self.assertEqual(len(model.functions), 0)
self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name)
Expand Down Expand Up @@ -100,7 +107,7 @@ def test_fold_single_node_function_nested(self):
)

_simple_function_folding.inline_simple_functions(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)

self.assertEqual(len(model.functions), 1)
self.assertEqual(model.functions[0].node[0].op_type, "Concat")
Expand Down Expand Up @@ -129,7 +136,7 @@ def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self
"""
)
_simple_function_folding.inline_simple_functions(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)
self.assertEqual(len(model.functions), 0)
self.assertEqual(len(model.graph.node), 3)
self.assertEqual(model.graph.node[0].attribute[0].i, 10)
Expand Down Expand Up @@ -172,7 +179,7 @@ def test_fold_nested_if_function_succeeds(self):
)

_simple_function_folding.inline_simple_functions(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)

self.assertEqual(len(model.functions), 0)
self.assertEqual(len(model.graph.node), 2)
Expand Down Expand Up @@ -213,7 +220,7 @@ def test_fold_function_with_unused_output(self):
)

_simple_function_folding.inline_functions_with_unused_outputs(model)
model = _remove_unused_function.remove_unused_functions(model)
model = _remove_unused_functions(model)
self.assertEqual(len(model.functions), 1)


Expand Down
5 changes: 3 additions & 2 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import logging

import onnxscript.optimizer
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.optimizer._remove_unused import remove_unused_nodes
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
Expand Down Expand Up @@ -51,6 +51,7 @@ def optimize_ir(
outer optimization loop if no change is detected in one iteration.
"""
del stop_if_no_change # Looks like rewriter doesn't support this yet.
# TODO(justinchuby): Update this to use a pass manager
_inliner.inline(model)
for _ in range(num_iterations):
_constant_folding.fold_constants(
Expand All @@ -60,4 +61,4 @@ def optimize_ir(
output_size_limit=output_size_limit,
)
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
remove_unused_nodes(model)
onnxscript.optimizer.remove_unused_nodes(model)
Loading
Loading