Skip to content

[IR] Export all common passes in onnxscript.ir.passes.common #2270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions docs/ir/ir_api/ir_passes_common.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
# ir.passes.common

```{eval-rst}
.. currentmodule:: onnxscript
```

## Built-in passes

Built-in passes provided by the ONNX IR

```{eval-rst}
.. autosummary::
:toctree: generated
:template: classtemplate.rst
:nosignatures:
.. automodule:: onnxscript.ir.passes.common
:show-inheritance:
:members:
:undoc-members:
:exclude-members: call

ir.passes.common.unused_removal.RemoveUnusedNodesPass
ir.passes.common.unused_removal.RemoveUnusedFunctionsPass
ir.passes.common.unused_removal.RemoveUnusedOpsetsPass
ir.passes.common.inliner.InlinePass
ir.passes.common.topological_sort.TopologicalSortPass
ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass
ir.passes.common.shape_inference.ShapeInferencePass
ir.passes.common.onnx_checker.CheckerPass
ir.passes.common.clear_metadata_and_docstring.ClearMetadataAndDocStringPass
```
54 changes: 29 additions & 25 deletions onnxscript/ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,35 @@
# Licensed under the MIT License.

__all__ = [
"clear_metadata_and_docstring",
"constant_manipulation",
"inliner",
"onnx_checker",
"shape_inference",
"topological_sort",
"unused_removal",
"AddInitializersToInputsPass",
"CheckerPass",
"ClearMetadataAndDocStringPass",
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
"RemoveInitializersFromInputsPass",
"RemoveUnusedFunctionsPass",
"RemoveUnusedNodesPass",
"RemoveUnusedOpsetsPass",
"ShapeInferencePass",
"TopologicalSortPass",
]

from onnxscript.ir.passes.common import (
clear_metadata_and_docstring,
constant_manipulation,
inliner,
onnx_checker,
shape_inference,
topological_sort,
unused_removal,
from onnxscript.ir.passes.common.clear_metadata_and_docstring import (
ClearMetadataAndDocStringPass,
)
from onnxscript.ir.passes.common.constant_manipulation import (
AddInitializersToInputsPass,
LiftConstantsToInitializersPass,
LiftSubgraphInitializersToMainGraphPass,
RemoveInitializersFromInputsPass,
)
from onnxscript.ir.passes.common.inliner import InlinePass
from onnxscript.ir.passes.common.onnx_checker import CheckerPass
from onnxscript.ir.passes.common.shape_inference import ShapeInferencePass
from onnxscript.ir.passes.common.topological_sort import TopologicalSortPass
from onnxscript.ir.passes.common.unused_removal import (
RemoveUnusedFunctionsPass,
RemoveUnusedNodesPass,
RemoveUnusedOpsetsPass,
)


def __set_module() -> None:
"""Set the module of all functions in this module to this public module."""
global_dict = globals()
for name in __all__:
global_dict[name].__module__ = __name__


__set_module()
2 changes: 2 additions & 0 deletions onnxscript/ir/passes/common/clear_metadata_and_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


class ClearMetadataAndDocStringPass(ir.passes.InPlacePass):
"""Clear all metadata and docstring from the model, graphs, nodes, and functions."""

def call(self, model: ir.Model) -> ir.passes.PassResult:
# 0. TODO: Should we clean model metadata and docstring?

Expand Down
2 changes: 2 additions & 0 deletions onnxscript/ir/passes/common/inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class InlinePassResult(ir.passes.PassResult):


class InlinePass(ir.passes.InPlacePass):
"""Inline model local functions to the main graph and clear function definitions."""

def __init__(self) -> None:
super().__init__()
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
Expand Down
8 changes: 6 additions & 2 deletions onnxscript/ir/passes/common/onnx_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"CheckerPass",
]

from typing import Literal

import onnx

from onnxscript import ir
Expand All @@ -18,11 +20,13 @@ class CheckerPass(ir.passes.PassBase):
"""Run onnx checker on the model."""

@property
def in_place(self) -> bool:
def in_place(self) -> Literal[True]:
"""This pass does not create a new model."""
return True

@property
def changes_input(self) -> bool:
def changes_input(self) -> Literal[False]:
"""This pass does not change the input model."""
return False

def __init__(
Expand Down
17 changes: 6 additions & 11 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import onnx

import onnxscript.ir.passes.common.inliner
import onnxscript.ir.passes.common.unused_removal
import onnxscript.ir.passes.common
import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
Expand Down Expand Up @@ -91,7 +90,7 @@
def inline(model: ir.Model) -> None:
"""Inline all function calls (recursively) in the model."""
if model.functions:
onnxscript.ir.passes.common.inliner.InlinePass()(model)
onnxscript.ir.passes.common.InlinePass()(model)


def fold_constants(
Expand All @@ -115,12 +114,10 @@
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
"""Removes unused nodes from a model inplace."""
if isinstance(model, ir.Model):
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model)
onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model)
else:
model_ir = ir.serde.deserialize_model(model)
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(
model_ir
).model
model_ir = onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model_ir).model
new_proto = ir.serde.serialize_model(model_ir)
model.Clear()
model.CopyFrom(new_proto)
Expand All @@ -129,12 +126,10 @@
def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
"""Removes unused functions from a model inplace."""
if isinstance(model, ir.Model):
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()(model)
onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model)

Check warning on line 129 in onnxscript/optimizer/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/__init__.py#L129

Added line #L129 was not covered by tests
else:
model_ir = ir.serde.deserialize_model(model)
model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()(
model_ir
).model
model_ir = onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model_ir).model

Check warning on line 132 in onnxscript/optimizer/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/__init__.py#L132

Added line #L132 was not covered by tests
new_proto = ir.serde.serialize_model(model_ir)
model.Clear()
model.CopyFrom(new_proto)
18 changes: 8 additions & 10 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.inliner
import onnxscript.ir.passes.common.unused_removal
import onnxscript.ir.passes.common
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding

Expand Down Expand Up @@ -45,20 +43,20 @@ def optimize_ir(
output_size_limit=output_size_limit,
),
rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),
onnxscript.ir.passes.common.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.RemoveUnusedFunctionsPass(),
onnxscript.ir.passes.common.RemoveUnusedOpsetsPass(),
],
steps=num_iterations,
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftSubgraphInitializersToMainGraphPass(),
onnxscript.ir.passes.common.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.LiftConstantsToInitializersPass(),
onnxscript.ir.passes.common.LiftSubgraphInitializersToMainGraphPass(),
]
if inline:
# Inline all functions first before optimizing
passes = [onnxscript.ir.passes.common.inliner.InlinePass(), *passes]
passes = [onnxscript.ir.passes.common.InlinePass(), *passes]
optimizer_pass = ir.passes.Sequential(*passes)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
13 changes: 6 additions & 7 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

import onnx

import onnxscript.ir.passes.common
from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils
from onnxscript.ir.passes.common import inliner as _inliner
from onnxscript.ir.passes.common import unused_removal as _unused_removal
from onnxscript.version_converter import _version_converter

logger = logging.getLogger(__name__)
Expand All @@ -40,14 +39,14 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
self.target_version = target_version
self.fallback = fallback
self.convert_pass = ir.passes.Sequential(
_inliner.InlinePass(),
onnxscript.ir.passes.common.InlinePass(),
_ConvertVersionPassRequiresInline(
target_version=target_version,
fallback=fallback,
),
_unused_removal.RemoveUnusedNodesPass(),
_unused_removal.RemoveUnusedFunctionsPass(),
_unused_removal.RemoveUnusedOpsetsPass(),
onnxscript.ir.passes.common.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.RemoveUnusedFunctionsPass(),
onnxscript.ir.passes.common.RemoveUnusedOpsetsPass(),
)

def call(self, model: ir.Model) -> ir.passes.PassResult:
Expand Down Expand Up @@ -78,7 +77,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
if model.functions:
raise ValueError(
"The model contains functions. The version conversion pass does not support "
"functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the "
"functions. Please use `onnxscript.ir.passes.common.InlinePass` to inline the "
f"functions before applying this pass ({self.__class__.__name__})."
)
if "" in model.graph.opset_imports:
Expand Down
2 changes: 1 addition & 1 deletion tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import onnxscript
import onnxscript.evaluator
import onnxscript.ir.passes.common.unused_removal
import onnxscript.ir.passes.common
from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from tests.function_libs.torch_lib import error_reproduction
Expand Down
Loading