Skip to content

Commit 35fd23c

Browse files
shubhambhokare1 and bmehta001
authored and committed
Allow fuse_xformers to return a count of different fusions applied (microsoft#2159)
1 parent 0be1074 commit 35fd23c

13 files changed

+124
-60
lines changed

onnxscript/rewriter/_fusion_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
from typing import Sequence, Union
5+
from typing import Callable, Sequence, Union
66

7-
from onnxscript import ir
7+
import onnxscript.ir as ir
8+
from onnxscript.rewriter import pattern
89

910
Dim = Union[int, ir.SymbolicDim]
1011

@@ -20,3 +21,20 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
2021
elif actual != bindings[expected]:
2122
return False
2223
return True
24+
25+
26+
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
27+
"""
28+
Apply the given fusion rules to the model and return the number of fusions applied.
29+
If debug is True, enable pattern matching tracer for debugging.
30+
"""
31+
32+
def apply_to(model: ir.Model, debug: bool = False) -> int:
33+
count = rules.apply_to_model(model)
34+
if count == 0 and debug:
35+
tracer = pattern.MatchingTracer()
36+
rules.apply_to_model(model, tracer=tracer)
37+
tracer.report()
38+
return count
39+
40+
return apply_to

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
fuse_rotary_embedding,
1515
)
1616
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
17-
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
17+
from onnxscript.rewriter.ort_fusions.skip_normalization import (
18+
fuse_skip_layer_normalization,
19+
fuse_skip_rms_normalization,
20+
)
1821

1922

2023
# Preliminary optimizations before applying the transformer fusions.
@@ -29,21 +32,35 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
2932
return model
3033

3134

32-
def fuse_xformers(model: ir.Model) -> ir.Model:
35+
def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
36+
"""
37+
Apply transformer-specific fusions to the given model.
38+
39+
Args:
40+
model: The input ONNX model represented as an `ir.Model`.
41+
42+
Returns:
43+
A tuple containing:
44+
- The optimized `ir.Model` after applying transformer-specific fusions.
45+
- A dictionary with a count of each of the fusions applied.
46+
"""
47+
fusion_count = dict()
48+
3349
model = _pre_optimize(model)
34-
fuse_rms_normalization(model)
35-
fuse_normalization(model)
36-
fuse_rotary_embedding(model)
37-
fuse_partial_rotary_embedding(model)
38-
fuse_cos_sin_cache(model)
39-
fuse_sdpa(model)
40-
fuse_mha(model)
41-
fuse_attention(model)
42-
fuse_gelu(model)
50+
fusion_count["rms_normalization"] = fuse_rms_normalization(model)
51+
fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
52+
fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
53+
fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
54+
fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
55+
fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
56+
fusion_count["sdpa"] = fuse_sdpa(model)
57+
fusion_count["mha"] = fuse_mha(model)
58+
fusion_count["attention"] = fuse_attention(model)
59+
fusion_count["gelu"] = fuse_gelu(model)
4360
# Finally: inline any intermediate fusion functions introduced that were not
4461
# consumed by other fusions, and eliminate any remaining unused nodes.
4562
optimize(model)
46-
return model
63+
return model, fusion_count
4764

4865

4966
def optimize_for_ort(model: ir.Model) -> None:

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,4 @@ def rewrite(
268268
)
269269

270270

271-
def fuse_attention(model: ir.Model, *, debug: bool = False) -> int:
272-
count = attention_rules.apply_to_model(model)
273-
if debug and count == 0:
274-
tracer = pattern.MatchingTracer()
275-
attention_rules.apply_to_model(model, tracer=tracer)
276-
tracer.report()
277-
return count
271+
fuse_attention = _fusion_utils.apply_fusion_rules(attention_rules)

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import numpy as np
66

77
import onnxscript.ir as ir
8-
from onnxscript.optimizer import remove_unused_nodes
9-
from onnxscript.rewriter import _ir_utils, pattern
8+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
109

1110
# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.
1211

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,20 @@ def test_fuse_xformers(self):
1717
onnxscript.optimizer.optimize(model)
1818
inputs = test.get_ort_inputs()
1919
original_outputs = ort_run("original", model, inputs)
20-
model = fuse_xformers(model)
20+
model, fusion_count = fuse_xformers(model)
21+
22+
# Check if the number of fusions applied for each fusion is correct
23+
self.assertEqual(fusion_count["rms_normalization"], 3)
24+
self.assertEqual(fusion_count["skip_layer_normalization"], 0)
25+
self.assertEqual(fusion_count["skip_rms_normalization"], 2)
26+
self.assertEqual(fusion_count["rotary_embedding"], 2)
27+
self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
28+
self.assertEqual(fusion_count["cos_sin_cache"], 2)
29+
self.assertEqual(fusion_count["sdpa"], 1)
30+
self.assertEqual(fusion_count["mha"], 0)
31+
self.assertEqual(fusion_count["attention"], 0)
32+
self.assertEqual(fusion_count["gelu"], 0)
33+
2134
new_outputs = ort_run("optimized", model, inputs)
2235
assert_allclose(new_outputs, original_outputs)
2336

onnxscript/rewriter/ort_fusions/gelu.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import math
66

7-
from onnxscript import ir
8-
from onnxscript.rewriter import pattern
7+
from onnxscript.rewriter import _fusion_utils, pattern
98

109
_sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
1110

@@ -33,5 +32,4 @@ def rewrite(self, op, x):
3332
gelu_rules = pattern.RewriteRuleSet([_rule])
3433

3534

36-
def fuse_gelu(model: ir.Model) -> None:
37-
gelu_rules.apply_to_model(model)
35+
fuse_gelu = _fusion_utils.apply_fusion_rules(gelu_rules)

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
import onnxscript.ir as ir
6-
from onnxscript.optimizer import remove_unused_nodes
7-
from onnxscript.rewriter import pattern
5+
from onnxscript.rewriter import _fusion_utils, pattern
86

97

108
class GroupQueryAttention(pattern.RewriteRuleClassBase):
@@ -150,8 +148,4 @@ def rewrite(
150148
gqa_rules = pattern.RewriteRuleSet([_rule1])
151149

152150

153-
def fuse_gqa(model: ir.Model) -> int:
154-
count = gqa_rules.apply_to_model(model)
155-
print(f"GQA count: {count}")
156-
remove_unused_nodes(model)
157-
return count
151+
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_smollm(self):
1717
model = smollm_test.get_onnx_model()
1818
onnxscript.optimizer.optimize(model)
1919
xformers.fuse_rms_normalization(model)
20-
xformers.fuse_normalization(model)
20+
xformers.fuse_skip_rms_normalization(model)
2121
xformers.fuse_rotary_embedding(model)
2222
xformers.fuse_cos_sin_cache(model)
2323

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import onnxscript.ir as ir
6-
from onnxscript.rewriter import _ir_utils, pattern
6+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
77

88
"""
99
RMS Normalization: This is referred to as SimplifiedLayerNormalization in the ORT codebase.
@@ -91,6 +91,4 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
9191
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
9292

9393

94-
def fuse_rms_normalization(model: ir.Model) -> None:
95-
count = rms_normalization_ruleset.apply_to_model(model)
96-
print(f"RMS Normalization count: {count}")
94+
fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
import onnxscript.ir as ir
6-
from onnxscript.rewriter import _ir_utils, pattern
5+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
76

87
# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
98
# for full rotation without interleaving.

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import math
66

7-
import onnxscript.ir as ir
8-
from onnxscript.rewriter import _ir_utils, pattern
7+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
98

109

1110
class SDPA(pattern.RewriteRuleClassBase):

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
from onnxscript.rewriter import pattern
6-
from onnxscript.rewriter.ort_fusions.rms_normalization import rms_normalization_rules
5+
from onnxscript.rewriter import _fusion_utils, pattern
76

87

9-
def _skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
8+
def _skip_rms_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
109
skip_sum = op.Add(input, skip)
1110
normalized = op.SimplifiedLayerNormalization(
1211
skip_sum,
@@ -18,7 +17,7 @@ def _skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
1817
return normalized, skip_sum
1918

2019

21-
def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
20+
def _skip_rms_normalization(op, input, skip, gamma, epsilon, stash_type):
2221
if stash_type.value != 1: # FLOAT type
2322
return None
2423
normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
@@ -32,15 +31,49 @@ def _skip_normalization(op, input, skip, gamma, epsilon, stash_type):
3231
return normalized, skip_sum
3332

3433

35-
_rule = pattern.RewriteRule(
36-
_skip_norm_pattern, _skip_normalization, matcher=pattern.SimplePatternMatcher
37-
)
34+
_skip_rms_rule = pattern.RewriteRule(_skip_rms_norm_pattern, _skip_rms_normalization)
35+
36+
skip_rms_normalization_rules = [_skip_rms_rule]
37+
skip_rms_normalization_ruleset = pattern.RewriteRuleSet(skip_rms_normalization_rules)
38+
39+
40+
def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
41+
skip_sum = op.Add(input, skip)
42+
normalized = op.LayerNormalization(
43+
skip_sum,
44+
gamma,
45+
beta,
46+
axis=-1,
47+
epsilon=epsilon,
48+
stash_type=stash_type,
49+
)
50+
return normalized
51+
52+
53+
def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type):
54+
if stash_type.value != 1: # FLOAT type
55+
return None
56+
normalized, _mean, _inv_std_var = op.SkipLayerNormalization(
57+
input,
58+
skip,
59+
gamma,
60+
beta,
61+
epsilon=epsilon,
62+
_outputs=3,
63+
_domain="com.microsoft",
64+
)
65+
return normalized
66+
67+
68+
_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization)
3869

39-
skip_normalization_rules = [_rule]
40-
normalization_rules = rms_normalization_rules + skip_normalization_rules
41-
normalization_ruleset = pattern.RewriteRuleSet(normalization_rules)
70+
skip_layer_normalization_rules = [_skip_layer_rule]
71+
skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
4272

4373

44-
def fuse_normalization(model):
45-
count = normalization_ruleset.apply_to_model(model)
46-
print(f"Normalization count: {count}")
74+
fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset)
75+
76+
77+
fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(
78+
skip_layer_normalization_ruleset
79+
)

onnxscript/rewriter/ort_fusions/skip_normalization_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import onnxscript.optimizer
88
from onnxscript.rewriter.ort_fusions._smollm_1 import TestData
99
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
10-
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
10+
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
11+
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_rms_normalization
1112

1213

1314
class TestSkipNormalization(unittest.TestCase):
@@ -17,7 +18,8 @@ def test_smollm(self):
1718
onnxscript.optimizer.optimize(model)
1819
inputs = smollm_test.get_ort_inputs()
1920
original_outputs = ort_run("original", model, inputs)
20-
fuse_normalization(model)
21+
fuse_rms_normalization(model)
22+
fuse_skip_rms_normalization(model)
2123
op_types = [n.op_type for n in model.graph]
2224
self.assertIn("SkipSimplifiedLayerNormalization", op_types)
2325
new_outputs = ort_run("optimized", model, inputs)

0 commit comments

Comments (0)