Skip to content

Commit 1d7ad44

Browse files
rebase changes
1 parent 6d36150 commit 1d7ad44

File tree

1 file changed

+56
-24
lines changed

1 file changed

+56
-24
lines changed

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 56 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -129,38 +129,70 @@ def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
         )
         return normalized, skip_sum
 
+    def check(
+        self, op, input, skip, gamma, beta, bias, epsilon, stash_type
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
+        check_result = pattern.MatchResult()
+        bindings: dict[str, Dim] = {}
 
-def _skip_layer_normalization_add_bias(
-    op, input, skip, gamma, beta, bias, epsilon, stash_type
-):
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        input,
-        skip,
-        gamma,
-        beta,
-        bias,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(bindings, val, dims)
+
+        if no_match(input, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']",
+                input,
+            )
+        if no_match(skip, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']",
+                skip,
+            )
+        if no_match(gamma, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {gamma} does not match expected dimensions ['D']",
+                gamma,
+            )
+        if no_match(beta, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {beta} does not match expected dimensions ['D']",
+                beta,
+            )
+        if self._has_bias:
+            if no_match(bias, ["D"]):
+                return check_result.fail(
+                    f"Shape mismatch: {bias} does not match expected dimensions ['D']",
+                    bias,
+                )
+
+        return check_result
+
+    def rewrite(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
+        normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+            input,
+            skip,
+            gamma,
+            beta,
+            bias,
+            epsilon=epsilon,
+            _outputs=4,
+            _domain="com.microsoft",
+        )
+        return normalized, skip_sum
 
 
 _skip_layer_add_bias_rule = SkipLayerNormFusion.rule(
     "SkipLayerNormBias", has_bias=True, bias_pre_add=False
 )
-_skip_layer_add_bias_rule = pattern.RewriteRule(
-    _skip_layer_norm_add_bias_pattern,
-    _skip_layer_normalization_add_bias,
-    name="SkipLayerNormAddBias",
+_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule(
+    "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True
 )
+_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False)
 
-
-skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule]
-skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
-
-
-fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset)
+skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
+    [_skip_layer_pre_add_bias_rule, _skip_layer_add_bias_rule, _skip_layer_rule]
+)
 
 
 fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(

0 commit comments

Comments (0)