diff --git a/docs/ir/ir_api/ir_passes_common.md b/docs/ir/ir_api/ir_passes_common.md index 695dc21950..37740160ce 100644 --- a/docs/ir/ir_api/ir_passes_common.md +++ b/docs/ir/ir_api/ir_passes_common.md @@ -1,25 +1,12 @@ # ir.passes.common -```{eval-rst} -.. currentmodule:: onnxscript -``` - -## Built-in passes - +Built-in passes provided by the ONNX IR ```{eval-rst} -.. autosummary:: - :toctree: generated - :template: classtemplate.rst - :nosignatures: +.. automodule:: onnxscript.ir.passes.common + :show-inheritance: + :members: + :undoc-members: + :exclude-members: call - ir.passes.common.unused_removal.RemoveUnusedNodesPass - ir.passes.common.unused_removal.RemoveUnusedFunctionsPass - ir.passes.common.unused_removal.RemoveUnusedOpsetsPass - ir.passes.common.inliner.InlinePass - ir.passes.common.topological_sort.TopologicalSortPass - ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass - ir.passes.common.shape_inference.ShapeInferencePass - ir.passes.common.onnx_checker.CheckerPass - ir.passes.common.clear_metadata_and_docstring.ClearMetadataAndDocStringPass ``` diff --git a/onnxscript/ir/passes/common/__init__.py b/onnxscript/ir/passes/common/__init__.py index c211572fd4..d1b4f176a2 100644 --- a/onnxscript/ir/passes/common/__init__.py +++ b/onnxscript/ir/passes/common/__init__.py @@ -2,31 +2,35 @@ # Licensed under the MIT License. __all__ = [ - "clear_metadata_and_docstring", - "constant_manipulation", - "inliner", - "onnx_checker", - "shape_inference", - "topological_sort", - "unused_removal", + "AddInitializersToInputsPass", + "CheckerPass", + "ClearMetadataAndDocStringPass", + "InlinePass", + "LiftConstantsToInitializersPass", + "LiftSubgraphInitializersToMainGraphPass", + "RemoveInitializersFromInputsPass", + "RemoveUnusedFunctionsPass", + "RemoveUnusedNodesPass", + "RemoveUnusedOpsetsPass", + "ShapeInferencePass", + "TopologicalSortPass", ] -from onnxscript.ir.passes.common import ( - clear_metadata_and_docstring, - constant_manipulation, - inliner, - onnx_checker, - shape_inference, - topological_sort, - unused_removal, +from onnxscript.ir.passes.common.clear_metadata_and_docstring import ( + ClearMetadataAndDocStringPass, +) +from onnxscript.ir.passes.common.constant_manipulation import ( + AddInitializersToInputsPass, + LiftConstantsToInitializersPass, + LiftSubgraphInitializersToMainGraphPass, + RemoveInitializersFromInputsPass, +) +from onnxscript.ir.passes.common.inliner import InlinePass +from onnxscript.ir.passes.common.onnx_checker import CheckerPass +from onnxscript.ir.passes.common.shape_inference import ShapeInferencePass +from onnxscript.ir.passes.common.topological_sort import TopologicalSortPass +from onnxscript.ir.passes.common.unused_removal import ( + RemoveUnusedFunctionsPass, + RemoveUnusedNodesPass, + RemoveUnusedOpsetsPass, ) - - -def __set_module() -> None: - """Set the module of all functions in this module to this public module.""" - global_dict = globals() - for name in __all__: - global_dict[name].__module__ = __name__ - - -__set_module() diff --git a/onnxscript/ir/passes/common/clear_metadata_and_docstring.py b/onnxscript/ir/passes/common/clear_metadata_and_docstring.py index f23787b6f6..0c1fa48cb0 100644 --- a/onnxscript/ir/passes/common/clear_metadata_and_docstring.py +++ b/onnxscript/ir/passes/common/clear_metadata_and_docstring.py @@ -16,6 +16,8 @@ class ClearMetadataAndDocStringPass(ir.passes.InPlacePass): + """Clear all metadata and docstring from the model, graphs, nodes, and functions.""" + def call(self, model: ir.Model) -> ir.passes.PassResult: # 0. TODO: Should we clean model metadata and docstring? diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 5cefc94268..3a4f97a8a7 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -198,6 +198,8 @@ class InlinePassResult(ir.passes.PassResult): class InlinePass(ir.passes.InPlacePass): + """Inline model local functions to the main graph and clear function definitions.""" + def __init__(self) -> None: super().__init__() self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} diff --git a/onnxscript/ir/passes/common/onnx_checker.py b/onnxscript/ir/passes/common/onnx_checker.py index 18a5c03c5e..b815629641 100644 --- a/onnxscript/ir/passes/common/onnx_checker.py +++ b/onnxscript/ir/passes/common/onnx_checker.py @@ -8,6 +8,8 @@ "CheckerPass", ] +from typing import Literal + import onnx from onnxscript import ir @@ -18,11 +20,13 @@ class CheckerPass(ir.passes.PassBase): """Run onnx checker on the model.""" @property - def in_place(self) -> bool: + def in_place(self) -> Literal[True]: + """This pass does not create a new model.""" return True @property - def changes_input(self) -> bool: + def changes_input(self) -> Literal[False]: + """This pass does not change the input model.""" return False def __init__( diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index a6e8ea2fc5..3cfb9c5b04 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -16,8 +16,7 @@ import onnx -import onnxscript.ir.passes.common.inliner -import onnxscript.ir.passes.common.unused_removal +import onnxscript.ir.passes.common import onnxscript.optimizer._constant_folding as constant_folding from onnxscript import ir from onnxscript.optimizer._constant_folding import ( @@ -91,7 +90,7 @@ def optimize( def inline(model: ir.Model) -> None: """Inline all function calls (recursively) in the model.""" if model.functions: - onnxscript.ir.passes.common.inliner.InlinePass()(model) + onnxscript.ir.passes.common.InlinePass()(model) def fold_constants( @@ -115,12 +114,10 @@ def fold_constants( def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: """Removes unused nodes from a model inplace.""" if isinstance(model, ir.Model): - onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model) + onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model) else: model_ir = ir.serde.deserialize_model(model) - model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()( - model_ir - ).model + model_ir = onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model_ir).model new_proto = ir.serde.serialize_model(model_ir) model.Clear() model.CopyFrom(new_proto) @@ -129,12 +126,10 @@ def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None: """Removes unused functions from a model inplace.""" if isinstance(model, ir.Model): - onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()(model) + onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model) else: model_ir = ir.serde.deserialize_model(model) - model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass()( - model_ir - ).model + model_ir = onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model_ir).model new_proto = ir.serde.serialize_model(model_ir) model.Clear() model.CopyFrom(new_proto) diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index f8994bd741..40787c6e74 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -4,9 +4,7 @@ import logging -import onnxscript.ir.passes.common.constant_manipulation -import onnxscript.ir.passes.common.inliner -import onnxscript.ir.passes.common.unused_removal +import onnxscript.ir.passes.common from onnxscript import ir, rewriter from onnxscript.optimizer import _constant_folding @@ -45,20 +43,20 @@ def optimize_ir( output_size_limit=output_size_limit, ), rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), - onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(), - onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(), - onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(), + onnxscript.ir.passes.common.RemoveUnusedNodesPass(), + onnxscript.ir.passes.common.RemoveUnusedFunctionsPass(), + onnxscript.ir.passes.common.RemoveUnusedOpsetsPass(), ], steps=num_iterations, early_stop=stop_if_no_change, ), - onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(), - onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(), - onnxscript.ir.passes.common.constant_manipulation.LiftSubgraphInitializersToMainGraphPass(), + onnxscript.ir.passes.common.RemoveUnusedNodesPass(), + onnxscript.ir.passes.common.LiftConstantsToInitializersPass(), + onnxscript.ir.passes.common.LiftSubgraphInitializersToMainGraphPass(), ] if inline: # Inline all functions first before optimizing - passes = [onnxscript.ir.passes.common.inliner.InlinePass(), *passes] + passes = [onnxscript.ir.passes.common.InlinePass(), *passes] optimizer_pass = ir.passes.Sequential(*passes) assert optimizer_pass.in_place result = optimizer_pass(model) diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 12d909f7b1..89696d6986 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -11,10 +11,9 @@ import onnx +import onnxscript.ir.passes.common from onnxscript import ir from onnxscript.ir.passes.common import _c_api_utils -from onnxscript.ir.passes.common import inliner as _inliner -from onnxscript.ir.passes.common import unused_removal as _unused_removal from onnxscript.version_converter import _version_converter logger = logging.getLogger(__name__) @@ -40,14 +39,14 @@ def __init__(self, target_version: int, fallback: bool = False) -> None: self.target_version = target_version self.fallback = fallback self.convert_pass = ir.passes.Sequential( - _inliner.InlinePass(), + onnxscript.ir.passes.common.InlinePass(), _ConvertVersionPassRequiresInline( target_version=target_version, fallback=fallback, ), - _unused_removal.RemoveUnusedNodesPass(), - _unused_removal.RemoveUnusedFunctionsPass(), - _unused_removal.RemoveUnusedOpsetsPass(), + onnxscript.ir.passes.common.RemoveUnusedNodesPass(), + onnxscript.ir.passes.common.RemoveUnusedFunctionsPass(), + onnxscript.ir.passes.common.RemoveUnusedOpsetsPass(), ) def call(self, model: ir.Model) -> ir.passes.PassResult: @@ -78,7 +77,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: if model.functions: raise ValueError( "The model contains functions. The version conversion pass does not support " - "functions. Please use `onnxscript.ir.passes.common.inliner.InlinePass` to inline the " + "functions. Please use `onnxscript.ir.passes.common.InlinePass` to inline the " f"functions before applying this pass ({self.__class__.__name__})." ) if "" in model.graph.opset_imports: diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index a9f922ce25..8de86e3551 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -35,7 +35,7 @@ import onnxscript import onnxscript.evaluator -import onnxscript.ir.passes.common.unused_removal +import onnxscript.ir.passes.common from onnxscript import ir from onnxscript.function_libs.torch_lib.ops import common as common_ops from tests.function_libs.torch_lib import error_reproduction