Skip to content

Handle empty rewrite rules in rewrite function #2164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxscript/optimizer/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:

class InlinePass(ir.passes.InPlacePass):
def __init__(self) -> None:
super().__init__()
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
self._opset_imports: dict[str, int] = {}
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/optimizer/_legacy/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
inline_simple_functions,
)
from onnxscript.optimizer._legacy.constant_folding import fold_constants
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +74,7 @@ def optimize(
onnxscript.optimizer.remove_unused_functions(model)
inline_functions_with_unused_outputs(model)
# NOTE: This is general rewrite rules
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
model = rewriter.rewrite(model)
if stop_if_no_change and not modified:
logger.debug("Stopping after %d iterations.", _)
break
Expand Down
20 changes: 1 addition & 19 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,11 @@
import logging

import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
gemm_to_matmul_add,
llama_rule_sets,
no_op,
)

logger = logging.getLogger(__name__)

_DEFAULT_REWRITE_RULES: tuple[rewriter.pattern.RewriteRule, ...] = (
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*llama_rule_sets.llama_p0_rule_set().rules,
)


def optimize_ir(
model: ir.Model,
Expand Down Expand Up @@ -61,7 +43,7 @@ def optimize_ir(
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
),
rewriter.RewritePass(_DEFAULT_REWRITE_RULES),
rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),
Expand Down
69 changes: 52 additions & 17 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,81 @@
from typing import Sequence, TypeVar, Union

__all__ = [
# Modules
"pattern",
# Functions
"rewrite",
"RewritePass",
]

import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import unused_removal
from onnxscript.rewriter import pattern
from onnxscript.rewriter import (
broadcast_to_matmul,
cast_constant_of_shape,
collapse_slices,
gemm_to_matmul_add,
llama_rule_sets,
no_op,
pattern,
)

PatternRewriteRule = pattern.RewriteRule

ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule, # type: ignore[has-type]
*cast_constant_of_shape.rules.rules,
*collapse_slices.rules.rules,
*llama_rule_sets.llama_p0_rule_set().rules,
)


class RewritePass(ir.passes.InPlacePass):
def __init__(
self,
pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (),
rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet,
/,
) -> None:
if pattern_rewrite_rules:
if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet):
# Create a pattern rule-set using provided rules
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet)
self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules
super().__init__()
if isinstance(rules, Sequence):
if not rules:
raise ValueError("rules must not be empty")
# Create a pattern rule-set using provided rules
rules = pattern.RewriteRuleSet(rules)
assert isinstance(rules, pattern.RewriteRuleSet)
self.rules: pattern.RewriteRuleSet = rules

def call(self, model: ir.Model) -> ir.passes.PassResult:
count = self.pattern_rewrite_rules.apply_to_model(model)
count = self.rules.apply_to_model(model)
if count:
print(f"Applied {count} of general pattern rewrite rules.")
return ir.passes.PassResult(model, bool(count))


def rewrite(
model: ModelProtoOrIr,
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (),
) -> ModelProtoOrIr:
model: _ModelProtoOrIr,
pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet]
| None = None,
) -> _ModelProtoOrIr:
"""Rewrite the model using the provided pattern rewrite rules.

Unused nodes, functions, and opsets will be removed after the rewrite.

Args:
model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model.
pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet.
If not provided, default rules will be applied. If empty, no rules will be applied
and the original model will be returned.

Returns:
The rewritten model as the same type as the input model.
"""
if pattern_rewrite_rules is None:
pattern_rewrite_rules = _DEFAULT_REWRITE_RULES
elif not pattern_rewrite_rules:
return model

if isinstance(model, onnx.ModelProto):
model_ir = ir.serde.deserialize_model(model)
proto = True
Expand Down
5 changes: 5 additions & 0 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,13 +1664,18 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:

class RewriteRuleSet:
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
if not rules:
raise ValueError("rules must contain at least one rule")
if commute:
rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
self.rules = rules
# We call remove_unused_nodes at end of rewriting if there is any rule that does
# NOT remove nodes (immediately when it is applied)
self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.rules})"

def _apply_to_graph_or_function(
self,
model: ir.Model,
Expand Down
Loading