Skip to content

Commit e13ab6e

Browse files
modify to add callable
1 parent ba51611 commit e13ab6e

File tree

9 files changed

+42
-28
lines changed

9 files changed

+42
-28
lines changed

onnxscript/rewriter/_fusion_utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
from typing import Sequence, Union
5+
from typing import Callable, Sequence, Union
66

7-
from onnxscript import ir
7+
8+
import onnxscript.ir as ir
9+
from onnxscript.rewriter import pattern
810

911
Dim = Union[int, ir.SymbolicDim]
1012

@@ -20,20 +22,20 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
2022
elif actual != bindings[expected]:
2123
return False
2224
return True
23-
import onnxscript.ir as ir
24-
from onnxscript.rewriter import pattern
2525

2626

27-
def apply_fusion_rules(
28-
rules: pattern.RewriteRule | pattern.RewriteRuleSet, model: ir.Model, debug: bool = False
29-
) -> int:
27+
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
3028
"""
3129
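Return a callable that applies the given fusion rules to a model and returns the number of fusions applied.
3230
If the callable is invoked with debug=True and no fusion was applied, the rules are re-applied with a pattern matching tracer and the tracer's report() is called for debugging.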
3331
"""
34-
count = rules.apply_to_model(model)
35-
if count == 0 and debug:
36-
tracer = pattern.MatchingTracer()
37-
rules.apply_to_model(model, tracer=tracer)
38-
tracer.report()
39-
return count
32+
33+
def apply_to(model: ir.Model, debug: bool = False) -> int:
34+
count = rules.apply_to_model(model)
35+
if count == 0 and debug:
36+
tracer = pattern.MatchingTracer()
37+
rules.apply_to_model(model, tracer=tracer)
38+
tracer.report()
39+
return count
40+
41+
return apply_to

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def rewrite(
170170

171171

172172
def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int:
173-
count = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules, model, debug=debug)
173+
fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)
174+
count = fuse_cos_sin_cache(model, debug=debug)
174175
if count != 0:
175176
remove_unused_nodes(model)
176177
return count

onnxscript/rewriter/ort_fusions/gelu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ def rewrite(self, op, x):
3434

3535

3636
def fuse_gelu(model: ir.Model, debug: bool = False) -> int:
37-
return _fusion_utils.apply_fusion_rules(gelu_rules, model, debug=debug)
37+
fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules)
38+
return fuse_gelu(model, debug=debug)

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def rewrite(
151151

152152

153153
def fuse_gqa(model: ir.Model, debug: bool = False) -> int:
154-
count = _fusion_utils.apply_fusion_rules(gqa_rules, model, debug=debug)
154+
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
155+
count = fuse_gqa(model, debug=debug)
155156
remove_unused_nodes(model)
156157
return count

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,5 @@ def rewrite(
261261

262262

263263
def fuse_mha(model: ir.Model, *, debug: bool = False) -> int:
264-
return _fusion_utils.apply_fusion_rules(mha_rules, model, debug=debug)
264+
fuse_mha = _fusion_utils.apply_fusion_rules(mha_rules)
265+
return fuse_mha(model, debug=debug)

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
9292

9393

9494
def fuse_rms_normalization(model: ir.Model, debug: bool = False) -> int:
95-
return _fusion_utils.apply_fusion_rules(rms_normalization_ruleset, model, debug=debug)
95+
fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
96+
return fuse_rms_normalization(model, debug=debug)

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,10 @@ def rewrite(self, op, x, end1, x_part_1_rope, **_):
121121

122122

123123
def fuse_rotary_embedding(model: ir.Model, debug: bool = False) -> int:
124-
return _fusion_utils.apply_fusion_rules(rotary_embedding_rules, model, debug=debug)
124+
fuse_rotary_embedding = _fusion_utils.apply_fusion_rules(rotary_embedding_rules)
125+
return fuse_rotary_embedding(model, debug=debug)
125126

126127

127128
def fuse_partial_rotary_embedding(model: ir.Model, debug: bool = False) -> int:
128-
return _fusion_utils.apply_fusion_rules(partial_embedding_rules, model, debug=debug)
129+
fuse_partial_rotary_embedding = _fusion_utils.apply_fusion_rules(partial_embedding_rules)
130+
return fuse_partial_rotary_embedding(model, debug=debug)

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,5 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
131131

132132

133133
def fuse_sdpa(model: ir.Model, debug: bool = False) -> int:
134-
return _fusion_utils.apply_fusion_rules(sdpa_rules, model, debug=debug)
134+
fuse_sdpa = _fusion_utils.apply_fusion_rules(sdpa_rules)
135+
return fuse_sdpa(model, debug=debug)

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from onnxscript.rewriter import _fusion_utils, pattern
77

88

9-
def _skip_rmsnorm_pattern(op, input, skip, gamma, epsilon, stash_type):
9+
def _skip_rms_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
1010
skip_sum = op.Add(input, skip)
1111
normalized = op.SimplifiedLayerNormalization(
1212
skip_sum,
@@ -33,14 +33,14 @@ def _skip_rms_normalization(op, input, skip, gamma, epsilon, stash_type):
3333

3434

3535
_skip_rms_rule = pattern.RewriteRule(
36-
_skip_rmsnorm_pattern, _skip_rms_normalization, matcher=pattern.SimplePatternMatcher
36+
_skip_rms_norm_pattern, _skip_rms_normalization, matcher=pattern.SimplePatternMatcher
3737
)
3838

3939
skip_rms_normalization_rules = [_skip_rms_rule]
4040
skip_rms_normalization_ruleset = pattern.RewriteRuleSet(skip_rms_normalization_rules)
4141

4242

43-
def _skip_layernorm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
43+
def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
4444
skip_sum = op.Add(input, skip)
4545
normalized = op.LayerNormalization(
4646
skip_sum,
@@ -69,18 +69,22 @@ def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type)
6969

7070

7171
_skip_layer_rule = pattern.RewriteRule(
72-
_skip_layernorm_pattern, _skip_layer_normalization, matcher=pattern.SimplePatternMatcher
72+
_skip_layer_norm_pattern, _skip_layer_normalization, matcher=pattern.SimplePatternMatcher
7373
)
7474

7575
skip_layer_normalization_rules = [_skip_layer_rule]
7676
skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
7777

7878

7979
def fuse_skip_rms_normalization(model: ir.Model, debug: bool = False) -> int:
80-
return _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset, model, debug=debug)
80+
fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(
81+
skip_rms_normalization_ruleset
82+
)
83+
return fuse_skip_rms_normalization(model, debug=debug)
8184

8285

8386
def fuse_skip_layer_normalization(model: ir.Model, debug: bool = False) -> int:
84-
return _fusion_utils.apply_fusion_rules(
85-
skip_layer_normalization_ruleset, model, debug=debug
87+
fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(
88+
skip_layer_normalization_ruleset
8689
)
90+
return fuse_skip_layer_normalization(model, debug=debug)

0 commit comments

Comments
 (0)