Skip to content

Commit 078b27f

Browse files
authored
Handle empty rewrite rules in rewrite function (#2164)
In #2149 the logic for skipping rewrite when no rules are provided was removed. This PR adds the logic back and hardens input checks. Now if no rules are provided to `rewrite()`, it will only run cleanup passes.
1 parent 58aeccd commit 078b27f

File tree

5 files changed

+60
-38
lines changed

5 files changed

+60
-38
lines changed

onnxscript/optimizer/_inliner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
190190

191191
class InlinePass(ir.passes.InPlacePass):
192192
def __init__(self) -> None:
193+
super().__init__()
193194
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
194195
self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
195196
self._opset_imports: dict[str, int] = {}

onnxscript/optimizer/_legacy/_optimizer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
inline_simple_functions,
1616
)
1717
from onnxscript.optimizer._legacy.constant_folding import fold_constants
18-
from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES
1918

2019
logger = logging.getLogger(__name__)
2120

@@ -75,7 +74,7 @@ def optimize(
7574
onnxscript.optimizer.remove_unused_functions(model)
7675
inline_functions_with_unused_outputs(model)
7776
# NOTE: This is general rewrite rules
78-
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
77+
model = rewriter.rewrite(model)
7978
if stop_if_no_change and not modified:
8079
logger.debug("Stopping after %d iterations.", _)
8180
break

onnxscript/optimizer/_optimizer.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,11 @@
55
import logging
66

77
import onnxscript.ir.passes.common.unused_removal
8-
import onnxscript.optimizer
98
from onnxscript import ir, rewriter
109
from onnxscript.optimizer import _constant_folding, _inliner
11-
from onnxscript.rewriter import (
12-
broadcast_to_matmul,
13-
cast_constant_of_shape,
14-
collapse_slices,
15-
gemm_to_matmul_add,
16-
llama_rule_sets,
17-
no_op,
18-
)
1910

2011
logger = logging.getLogger(__name__)
2112

22-
_DEFAULT_REWRITE_RULES: tuple[rewriter.pattern.RewriteRule, ...] = (
23-
*no_op.rules.rules, # TODO: merge this rule into constant folding?
24-
*broadcast_to_matmul.rules.rules,
25-
gemm_to_matmul_add.rule, # type: ignore[has-type]
26-
*cast_constant_of_shape.rules.rules,
27-
*collapse_slices.rules.rules,
28-
*llama_rule_sets.llama_p0_rule_set().rules,
29-
)
30-
3113

3214
def optimize_ir(
3315
model: ir.Model,
@@ -61,7 +43,7 @@ def optimize_ir(
6143
input_size_limit=input_size_limit,
6244
output_size_limit=output_size_limit,
6345
),
64-
rewriter.RewritePass(_DEFAULT_REWRITE_RULES),
46+
rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
6547
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
6648
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
6749
onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),

onnxscript/rewriter/__init__.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,46 +5,81 @@
55
from typing import Sequence, TypeVar, Union
66

77
__all__ = [
8-
# Modules
98
"pattern",
10-
# Functions
119
"rewrite",
10+
"RewritePass",
1211
]
1312

1413
import onnx
1514

1615
from onnxscript import ir
1716
from onnxscript.ir.passes.common import unused_removal
18-
from onnxscript.rewriter import pattern
17+
from onnxscript.rewriter import (
18+
broadcast_to_matmul,
19+
cast_constant_of_shape,
20+
collapse_slices,
21+
gemm_to_matmul_add,
22+
llama_rule_sets,
23+
no_op,
24+
pattern,
25+
)
1926

20-
PatternRewriteRule = pattern.RewriteRule
21-
22-
ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)
27+
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
28+
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
29+
*no_op.rules.rules, # TODO: merge this rule into constant folding?
30+
*broadcast_to_matmul.rules.rules,
31+
gemm_to_matmul_add.rule, # type: ignore[has-type]
32+
*cast_constant_of_shape.rules.rules,
33+
*collapse_slices.rules.rules,
34+
*llama_rule_sets.llama_p0_rule_set().rules,
35+
)
2336

2437

2538
class RewritePass(ir.passes.InPlacePass):
2639
def __init__(
2740
self,
28-
pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (),
41+
rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet,
42+
/,
2943
) -> None:
30-
if pattern_rewrite_rules:
31-
if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet):
32-
# Create a pattern rule-set using provided rules
33-
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
34-
assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet)
35-
self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules
44+
super().__init__()
45+
if isinstance(rules, Sequence):
46+
if not rules:
47+
raise ValueError("rules must not be empty")
48+
# Create a pattern rule-set using provided rules
49+
rules = pattern.RewriteRuleSet(rules)
50+
assert isinstance(rules, pattern.RewriteRuleSet)
51+
self.rules: pattern.RewriteRuleSet = rules
3652

3753
def call(self, model: ir.Model) -> ir.passes.PassResult:
38-
count = self.pattern_rewrite_rules.apply_to_model(model)
54+
count = self.rules.apply_to_model(model)
3955
if count:
4056
print(f"Applied {count} of general pattern rewrite rules.")
4157
return ir.passes.PassResult(model, bool(count))
4258

4359

4460
def rewrite(
45-
model: ModelProtoOrIr,
46-
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (),
47-
) -> ModelProtoOrIr:
61+
model: _ModelProtoOrIr,
62+
pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet]
63+
| None = None,
64+
) -> _ModelProtoOrIr:
65+
"""Rewrite the model using the provided pattern rewrite rules.
66+
67+
Unused nodes, functions, and opsets will be removed after the rewrite.
68+
69+
Args:
70+
model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model.
71+
pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet.
72+
If not provided, default rules will be applied. If empty, no rules will be applied
73+
and the original model will be returned.
74+
75+
Returns:
76+
The rewritten model as the same type as the input model.
77+
"""
78+
if pattern_rewrite_rules is None:
79+
pattern_rewrite_rules = _DEFAULT_REWRITE_RULES
80+
elif not pattern_rewrite_rules:
81+
return model
82+
4883
if isinstance(model, onnx.ModelProto):
4984
model_ir = ir.serde.deserialize_model(model)
5085
proto = True

onnxscript/rewriter/pattern.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,13 +1664,18 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
16641664

16651665
class RewriteRuleSet:
16661666
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
1667+
if not rules:
1668+
raise ValueError("rules must contain at least one rule")
16671669
if commute:
16681670
rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
16691671
self.rules = rules
16701672
# We call remove_unused_nodes at end of rewriting if there is any rule that does
16711673
# NOT remove nodes (immediately when it is applied)
16721674
self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)
16731675

1676+
def __repr__(self) -> str:
1677+
return f"{self.__class__.__name__}({self.rules})"
1678+
16741679
def _apply_to_graph_or_function(
16751680
self,
16761681
model: ir.Model,

0 commit comments

Comments
 (0)