diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index 383e0eb99b..ee6e366608 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -22,7 +22,15 @@ def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False def pattern(self, op, input, skip, gamma, bias, epsilon, stash_type): if self._has_bias and self._bias_pre_add: input = op.Add(input, bias) - skip_sum = op.Add(input, skip) + + # Support different combinations of addition of input and skip + skip_sum_pattern_1 = op.Add(skip, input) + skip_sum_pattern_2 = op.Add(input, skip) + skip_sum = pattern.OrValue( + [skip_sum_pattern_1, skip_sum_pattern_2], + name="skip_sum", + ) + if self._has_bias and not self._bias_pre_add: skip_sum = op.Add(skip_sum, bias) # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. @@ -36,7 +44,17 @@ def pattern(self, op, input, skip, gamma, bias, epsilon, stash_type): ) return normalized, skip_sum - def check(self, op, input, skip, gamma, bias, epsilon, stash_type) -> pattern.MatchResult: # type: ignore[name-defined] + def check( + self, + op, + input, + skip, + gamma, + bias, + epsilon, + stash_type, + **_, + ) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SkipSimplifiedLayerNormalization op.""" check_result = pattern.MatchResult() bindings: dict[str, Dim] = {} @@ -68,7 +86,17 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: return check_result - def rewrite(self, op, input, skip, gamma, bias, epsilon, stash_type): + def rewrite( + self, + op, + input, + skip, + gamma, + bias, + epsilon, + stash_type, + **_, + ): if self._has_bias: normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( input, @@ -116,7 +144,12 @@ def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type): if self._has_bias and self._bias_pre_add: input = op.Add(input, bias) - skip_sum = op.Add(input, skip) + + # Support different combinations of addition of input and skip + skip_sum_pattern_1 = op.Add(skip, input) + skip_sum_pattern_2 = op.Add(input, skip) + skip_sum = pattern.OrValue([skip_sum_pattern_1, skip_sum_pattern_2], name="skip_sum") + if self._has_bias and not self._bias_pre_add: skip_sum = op.Add(skip_sum, bias) normalized = op.LayerNormalization( @@ -130,7 +163,16 @@ def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type): return normalized, skip_sum def check( - self, op, input, skip, gamma, beta, bias, epsilon, stash_type + self, + op, + input, + skip, + gamma, + beta, + bias, + epsilon, + stash_type, + **_, ) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" check_result = pattern.MatchResult() @@ -168,7 +210,18 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: return check_result - def rewrite(self, op, input, skip, gamma, beta, bias, epsilon, stash_type): + def rewrite( + self, + op, + input, + skip, + gamma, + beta, + bias, + epsilon, + stash_type, + **_, + ): normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( input, skip,