Skip to content

Commit 91da7b8

Browse files
justinchubybmehta001
authored andcommitted
Turn inliner into a pass and use it in rewriter & optimizer (microsoft#2149)
Use passes in optimizer and rewriter. 1. By opting into using the pass infra early, we get the benefit of getting the additional features in pass infra w/o having to pay higher refactoring cost in the future. We will be able to add more sophisticated debug utilities/snapshot capabilities etc. to the passes. 2. Since we are offering the pass infra to users, we can start validating it internally by using it here. If order altering becomes a valid use case we can expect users may need that and we can create appropriate facilities to support the usage.
1 parent 69dbc7b commit 91da7b8

File tree

3 files changed

+71
-39
lines changed

3 files changed

+71
-39
lines changed

onnxscript/optimizer/_inliner.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class _CopyReplace:
4747

4848
def __init__(
4949
self,
50-
inliner: _Inliner,
50+
inliner: InlinePass,
5151
attr_map: dict[str, ir.Attr | ir.RefAttr],
5252
value_map: dict[ir.Value, ir.Value | None],
5353
metadata_props: dict[str, str],
@@ -188,15 +188,29 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
188188
return {id: id_abbreviation(id) for id in function_ids}
189189

190190

191-
class _Inliner:
192-
def __init__(self, model: ir.Model) -> None:
193-
self._functions = model.functions
194-
self._function_id_abbreviations = _abbreviate(self._functions.keys())
195-
self._opset_imports = model.opset_imports
191+
class InlinePass(ir.passes.InPlacePass):
192+
def __init__(self) -> None:
193+
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
194+
self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
195+
self._opset_imports: dict[str, int] = {}
196196
self.used_value_names: set[str] = set()
197197
self.used_node_names: set[str] = set()
198198
self.node_context: dict[ir.Node, CallStack] = {}
199199

200+
def _reset(self, model: ir.Model) -> None:
201+
self._functions = model.functions
202+
self._function_id_abbreviations = _abbreviate(self._functions.keys())
203+
self._opset_imports = model.opset_imports
204+
self.used_value_names = set()
205+
self.used_node_names = set()
206+
self.node_context = {}
207+
208+
def call(self, model: ir.Model) -> ir.passes.PassResult:
209+
self._reset(model)
210+
modified = self.inline_calls_in(model.graph)
211+
model.functions.clear()
212+
return ir.passes.PassResult(model, modified)
213+
200214
def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
201215
id = node.op_identifier()
202216
function = self._functions[id]
@@ -249,7 +263,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
249263
output_values = [value_map[output] for output in function.outputs]
250264
return nodes, output_values # type: ignore
251265

252-
def inline_calls_in(self, graph: ir.Graph) -> None:
266+
def inline_calls_in(self, graph: ir.Graph) -> bool:
253267
for input in graph.inputs:
254268
if input.name is not None:
255269
self.used_value_names.add(input.name)
@@ -302,11 +316,10 @@ def inline_calls_in(self, graph: ir.Graph) -> None:
302316
elif attr.type == ir.AttributeType.GRAPHS:
303317
for graph in attr.as_graphs():
304318
self.inline_calls_in(graph)
319+
return bool(id_count)
305320

306321

307322
def inline(model: ir.Model) -> None:
308323
"""Inline all function calls (recursively) in the model."""
309324
if model.functions:
310-
inliner = _Inliner(model)
311-
inliner.inline_calls_in(model.graph)
312-
model.functions.clear()
325+
InlinePass()(model)

onnxscript/optimizer/_optimizer.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import logging
66

7+
import onnxscript.ir.passes.common.unused_removal
78
import onnxscript.optimizer
89
from onnxscript import ir, rewriter
910
from onnxscript.optimizer import _constant_folding, _inliner
@@ -50,15 +51,26 @@ def optimize_ir(
5051
stop_if_no_change: Not supported currently (has no effect). Meant to stop the
5152
outer optimization loop if no change is detected in one iteration.
5253
"""
53-
del stop_if_no_change # Looks like rewriter doesn't support this yet.
54-
# TODO(justinchuby): Update this to use a pass manager
55-
_inliner.inline(model)
56-
for _ in range(num_iterations):
57-
_constant_folding.fold_constants(
58-
model,
59-
onnx_shape_inference=onnx_shape_inference,
60-
input_size_limit=input_size_limit,
61-
output_size_limit=output_size_limit,
62-
)
63-
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
64-
onnxscript.optimizer.remove_unused_nodes(model)
54+
optimizer_pass = ir.passes.Sequential(
55+
_inliner.InlinePass(),
56+
ir.passes.PassManager(
57+
[
58+
_constant_folding.FoldConstantsPass(
59+
external_data_folder="",
60+
shape_inference=onnx_shape_inference,
61+
input_size_limit=input_size_limit,
62+
output_size_limit=output_size_limit,
63+
),
64+
rewriter.RewritePass(_DEFAULT_REWRITE_RULES),
65+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
66+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
67+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),
68+
],
69+
steps=num_iterations,
70+
early_stop=stop_if_no_change,
71+
),
72+
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
73+
)
74+
assert optimizer_pass.in_place
75+
result = optimizer_pass(model)
76+
assert result.model is model

onnxscript/rewriter/__init__.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,44 +18,51 @@
1818
from onnxscript.ir.passes.common import unused_removal
1919
from onnxscript.rewriter import pattern
2020

21-
RewriteRuleSet = pattern.RewriteRuleSet
2221
PatternRewriteRule = pattern.RewriteRule
2322
FunctionRewriteRule = function_rule.FunctionRewriteRule
2423

2524
ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)
2625

2726

27+
class RewritePass(ir.passes.InPlacePass):
28+
def __init__(
29+
self,
30+
pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (),
31+
) -> None:
32+
if pattern_rewrite_rules:
33+
if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet):
34+
# Create a pattern rule-set using provided rules
35+
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
36+
assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet)
37+
self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules
38+
39+
def call(self, model: ir.Model) -> ir.passes.PassResult:
40+
count = self.pattern_rewrite_rules.apply_to_model(model)
41+
if count:
42+
print(f"Applied {count} of general pattern rewrite rules.")
43+
return ir.passes.PassResult(model, bool(count))
44+
45+
2846
def rewrite(
2947
model: ModelProtoOrIr,
30-
function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (),
31-
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (),
48+
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (),
3249
) -> ModelProtoOrIr:
3350
if isinstance(model, onnx.ModelProto):
3451
model_ir = ir.serde.deserialize_model(model)
3552
proto = True
3653
else:
3754
model_ir = model
3855
proto = False
39-
if function_rewrite_rules:
40-
for rule_cls in function_rewrite_rules:
41-
count, model_ir = rule_cls().apply_to_model(model_ir)
42-
if count > 0:
43-
print(f"Applied {count} of rewrite rules.")
44-
if pattern_rewrite_rules:
45-
if not isinstance(pattern_rewrite_rules, RewriteRuleSet):
46-
# Create a pattern rule-set using provided rules
47-
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
48-
count = pattern_rewrite_rules.apply_to_model(model_ir)
49-
if count:
50-
print(f"Applied {count} of general pattern rewrite rules.")
51-
unused_remover = ir.passes.PassManager(
56+
57+
rewrite_pass = ir.passes.PassManager(
5258
(
59+
RewritePass(pattern_rewrite_rules),
5360
unused_removal.RemoveUnusedNodesPass(),
5461
unused_removal.RemoveUnusedFunctionsPass(),
5562
unused_removal.RemoveUnusedOpsetsPass(),
5663
)
5764
)
58-
model_ir = unused_remover(model_ir).model
65+
model_ir = rewrite_pass(model_ir).model
5966
if proto:
6067
return ir.serde.serialize_model(model_ir)
6168
return model_ir # type: ignore[return-value]

0 commit comments

Comments
 (0)