
Commit 34e7ba8

Rewrite Skip fusions with check functions (#2259)
- Rewrite SkipLayerNorm fusions and SkipRMSNorm fusions to match the format of other ort-fusion patterns.
- Add check functions for ensuring shapes are as expected.
- Move these fusions out of PR #2221

The fusion supports patterns of the form:
- `Add(input, skip) -> Norm`
- `Add(input, skip) -> Add(result, bias) -> Norm`
- `Add(input, bias) -> Add(result, skip) -> Norm`

NOTE: These fusions should support:
- Planned whisper-related optimizations
- Fixing benchmark failures stemming from wrong bias shapes in SkipLayerNorm fusions
1 parent 2766661 commit 34e7ba8
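
For orientation, here is a minimal model matching the first supported shape, `Add(input, skip) -> Norm` (a sketch using `onnx.helper`; the tensor names and sizes are illustrative, not taken from the commit). Provided the shape checks below pass, `fuse_skip_layer_normalization` replaces the `Add`/`LayerNormalization` pair with a single `com.microsoft.SkipLayerNormalization` node.

```python
# Illustrative only: a graph matching Add(input, skip) -> LayerNormalization.
import onnx
from onnx import TensorProto, helper

B, S, D = 2, 8, 64  # batch, sequence, hidden -- illustrative sizes

graph = helper.make_graph(
    nodes=[
        helper.make_node("Add", ["input", "skip"], ["skip_sum"]),
        helper.make_node(
            "LayerNormalization",
            ["skip_sum", "gamma", "beta"],
            ["normalized"],
            axis=-1,
            epsilon=1e-5,
        ),
    ],
    name="skip_layer_norm_example",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [B, S, D]),
        helper.make_tensor_value_info("skip", TensorProto.FLOAT, [B, S, D]),
        helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [D]),
        helper.make_tensor_value_info("beta", TensorProto.FLOAT, [D]),
    ],
    outputs=[helper.make_tensor_value_info("normalized", TensorProto.FLOAT, [B, S, D])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
```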


onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 183 additions & 101 deletions
```diff
@@ -2,115 +2,197 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from onnxscript.rewriter import _fusion_utils, pattern
+from typing import Sequence, Union
 
+import onnxscript.ir as ir
+from onnxscript.rewriter import _fusion_utils, pattern
 
-def _skip_rms_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
-    skip_sum = op.Add(input, skip)
-    normalized = op.SimplifiedLayerNormalization(
-        skip_sum,
-        gamma,
-        axis=-1,
-        epsilon=epsilon,
-        stash_type=stash_type,
-    )
-    return normalized, skip_sum
-
-
-def _skip_rms_normalization(op, input, skip, gamma, epsilon, stash_type):
-    if stash_type.value != 1:  # FLOAT type
-        return None
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
-        input,
-        skip,
-        gamma,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
-
-
-_skip_rms_rule = pattern.RewriteRule(_skip_rms_norm_pattern, _skip_rms_normalization)
-
-skip_rms_normalization_rules = [_skip_rms_rule]
-skip_rms_normalization_ruleset = pattern.RewriteRuleSet(skip_rms_normalization_rules)
-
-
-def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type):
-    skip_sum = op.Add(input, skip)
-    normalized = op.LayerNormalization(
-        skip_sum,
-        gamma,
-        beta,
-        axis=-1,
-        epsilon=epsilon,
-        stash_type=stash_type,
-    )
-    return normalized, skip_sum
-
-
-def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type):
-    if stash_type.value != 1:  # FLOAT type
-        return None
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        input,
-        skip,
-        gamma,
-        beta,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
-
-
-# Fusion rule for Add + SkipLayerNormalization
-def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type):
-    bias_sum = op.Add(input, bias)
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        bias_sum,
-        skip,
-        gamma,
-        beta,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
-
-
-def _skip_layer_normalization_add_bias(
-    op, input, skip, gamma, beta, bias, epsilon, stash_type
-):
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        input,
-        skip,
-        gamma,
-        beta,
-        bias,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
-
-
-_skip_layer_rule = pattern.RewriteRule(
-    _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm"
+Dim = Union[int, ir.SymbolicDim]
+
+# Fusion rule for SkipRMSNormalization
+
+
+class SkipRmsNormFusion(pattern.RewriteRuleClassBase):
+    def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False):
+        """Fusion rule for SkipRMSNormalization."""
+        super().__init__(name=name)
+        self._has_bias = has_bias
+        self._bias_pre_add = bias_pre_add
+
+    def pattern(self, op, input, skip, gamma, bias, epsilon, stash_type):
+        if self._has_bias and self._bias_pre_add:
+            input = op.Add(input, bias)
+        skip_sum = op.Add(input, skip)
+        if self._has_bias and not self._bias_pre_add:
+            skip_sum = op.Add(skip_sum, bias)
+        # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
+        # No need to use com.microsoft domain here; but this is a custom op in ORT.
+        normalized = op.SimplifiedLayerNormalization(
+            skip_sum,
+            gamma,
+            axis=-1,
+            epsilon=epsilon,
+            stash_type=stash_type,
+        )
+        return normalized, skip_sum
+
+    def check(self, op, input, skip, gamma, bias, epsilon, stash_type) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of SkipSimplifiedLayerNormalization op."""
+        check_result = pattern.MatchResult()
+        bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(bindings, val, dims)
+
+        if no_match(input, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']",
+                input,
+            )
+        if no_match(skip, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']",
+                skip,
+            )
+        if no_match(gamma, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {gamma} does not match expected dimensions ['D']",
+                gamma,
+            )
+        if self._has_bias:
+            if no_match(bias, ["D"]):
+                return check_result.fail(
+                    f"Shape mismatch: {bias} does not match expected dimensions ['D']",
+                    bias,
+                )
+
+        return check_result
+
+    def rewrite(self, op, input, skip, gamma, bias, epsilon, stash_type):
+        if self._has_bias:
+            normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
+                input,
+                skip,
+                gamma,
+                bias,
+                epsilon=epsilon,
+                _outputs=4,
+                _domain="com.microsoft",
+            )
+        else:
+            normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization(
+                input,
+                skip,
+                gamma,
+                epsilon=epsilon,
+                _outputs=4,
+                _domain="com.microsoft",
+            )
+        return normalized, skip_sum
+
+
+_skip_rms_add_bias_rule = SkipRmsNormFusion.rule(
+    "SkipRmsNormBias", has_bias=True, bias_pre_add=False
 )
```
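
For reference, `SkipSimplifiedLayerNormalization` computes an RMS normalization over the skip sum. Below is a rough NumPy equivalent of what the matched subgraph (and the fused op) computes, ignoring `stash_type` and the optional mean/inverse-std outputs; this is an illustration, not ORT's implementation.

```python
import numpy as np

def skip_simplified_layer_norm(x, skip, gamma, bias=None, epsilon=1e-5):
    # Sum the residual inputs; bias placement (pre- or post-skip) does not
    # change the result, since addition is commutative and associative.
    skip_sum = x + skip if bias is None else x + skip + bias
    # SimplifiedLayerNormalization is RMSNorm: no mean subtraction, no beta.
    rms = np.sqrt(np.mean(skip_sum**2, axis=-1, keepdims=True) + epsilon)
    return skip_sum / rms * gamma, skip_sum
```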
```diff
-_skip_layer_add_bias_rule = pattern.RewriteRule(
-    _skip_layer_norm_add_bias_pattern,
-    _skip_layer_normalization_add_bias,
-    name="SkipLayerNormAddBias",
+_skip_rms_pre_add_bias_rule = SkipRmsNormFusion.rule(
+    "SkipRmsNormPreBias", has_bias=True, bias_pre_add=True
 )
+_skip_rms_rule = SkipRmsNormFusion.rule("SkipRmsNorm", has_bias=False)
 
+skip_rms_normalization_ruleset = pattern.RewriteRuleSet(
+    [_skip_rms_pre_add_bias_rule, _skip_rms_add_bias_rule, _skip_rms_rule]
+)
+fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset)
 
```
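Note that the ruleset lists the pre-bias and bias rules ahead of the plain rule, so the more specific bias-carrying patterns are tried before the bare `Add -> Norm` pattern. A sketch of how the new entry point might be used (assuming that, like the other ort-fusion helpers, the callable returned by `apply_fusion_rules` takes an `ir.Model` and returns the number of fusions applied; the file path is illustrative):

```python
import onnx
import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_skip_rms_normalization

model = ir.from_proto(onnx.load("model.onnx"))  # illustrative input model
count = fuse_skip_rms_normalization(model)
print(f"SkipSimplifiedLayerNormalization fusions applied: {count}")
```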
```diff
-skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule]
-skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
 
+# Fusion rule for SkipLayerNormalization
+class SkipLayerNormFusion(pattern.RewriteRuleClassBase):
+    def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False):
+        """Fusion rule for SkipLayerNormalization."""
+        super().__init__(name=name)
+        self._has_bias = has_bias
+        self._bias_pre_add = bias_pre_add
+
+    def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
+        if self._has_bias and self._bias_pre_add:
+            input = op.Add(input, bias)
+        skip_sum = op.Add(input, skip)
+        if self._has_bias and not self._bias_pre_add:
+            skip_sum = op.Add(skip_sum, bias)
+        normalized = op.LayerNormalization(
+            skip_sum,
+            gamma,
+            beta,
+            axis=-1,
+            epsilon=epsilon,
+            stash_type=stash_type,
+        )
+        return normalized, skip_sum
+
+    def check(
+        self, op, input, skip, gamma, beta, bias, epsilon, stash_type
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of SkipLayerNormalization op."""
+        check_result = pattern.MatchResult()
+        bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(bindings, val, dims)
+
+        if no_match(input, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']",
+                input,
+            )
+        if no_match(skip, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']",
+                skip,
+            )
+        if no_match(gamma, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {gamma} does not match expected dimensions ['D']",
+                gamma,
+            )
+        if no_match(beta, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {beta} does not match expected dimensions ['D']",
+                beta,
+            )
+        if self._has_bias:
+            if no_match(bias, ["D"]):
+                return check_result.fail(
+                    f"Shape mismatch: {bias} does not match expected dimensions ['D']",
+                    bias,
+                )
+
+        return check_result
+
+    def rewrite(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
+        normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+            input,
+            skip,
+            gamma,
+            beta,
+            bias,
+            epsilon=epsilon,
+            _outputs=4,
+            _domain="com.microsoft",
+        )
+        return normalized, skip_sum
+
+
+_skip_layer_add_bias_rule = SkipLayerNormFusion.rule(
+    "SkipLayerNormBias", has_bias=True, bias_pre_add=False
+)
+_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule(
+    "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True
+)
+_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False)
 
```
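The `check` methods thread a single `bindings` dict through every `no_match` call, so a dimension name such as "D" must resolve to the same concrete or symbolic value across `input`, `skip`, `gamma`, `beta`, and `bias`; this is what catches the wrong-bias-shape cases mentioned in the commit message. `_fusion_utils._check_shape` is internal to onnxscript; the following is a hypothetical sketch of the binding idea, not its actual source:

```python
def check_shape(bindings, val, dims):
    # Hypothetical illustration of a shape-binding check:
    # bind each dimension name on first use, then demand consistency.
    if val.shape is None or len(val.shape) != len(dims):
        return False
    for actual, name in zip(val.shape, dims):
        if bindings.setdefault(name, actual) != actual:
            return False  # same name bound to two different dims
    return True
```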
```diff
-fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset)
+skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
+    [_skip_layer_pre_add_bias_rule, _skip_layer_add_bias_rule, _skip_layer_rule]
+)
 
 
 fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(
```
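
With both rulesets rewritten in the same style, the two entry points can be chained; a minimal sketch (again assuming each callable returned by `apply_fusion_rules` takes an `ir.Model` and returns a match count):

```python
from onnxscript.rewriter.ort_fusions.skip_normalization import (
    fuse_skip_layer_normalization,
    fuse_skip_rms_normalization,
)

def fuse_skip_normalizations(model):
    # Illustrative helper, not part of the commit: run both fusions.
    total = fuse_skip_rms_normalization(model)
    total += fuse_skip_layer_normalization(model)
    return total
```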
