|
2 | 2 | # Licensed under the MIT License.
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from onnxscript.rewriter import _fusion_utils, pattern |
| 5 | +from typing import Sequence, Union |
6 | 6 |
|
| 7 | +import onnxscript.ir as ir |
| 8 | +from onnxscript.rewriter import _fusion_utils, pattern |
7 | 9 |
|
8 |
| -def _skip_rms_norm_pattern(op, input, skip, gamma, epsilon, stash_type): |
9 |
| - skip_sum = op.Add(input, skip) |
10 |
| - normalized = op.SimplifiedLayerNormalization( |
11 |
| - skip_sum, |
12 |
| - gamma, |
13 |
| - axis=-1, |
14 |
| - epsilon=epsilon, |
15 |
| - stash_type=stash_type, |
16 |
| - ) |
17 |
| - return normalized, skip_sum |
18 |
| - |
19 |
| - |
20 |
| -def _skip_rms_normalization(op, input, skip, gamma, epsilon, stash_type): |
21 |
| - if stash_type.value != 1: # FLOAT type |
22 |
| - return None |
23 |
| - normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( |
24 |
| - input, |
25 |
| - skip, |
26 |
| - gamma, |
27 |
| - epsilon=epsilon, |
28 |
| - _outputs=4, |
29 |
| - _domain="com.microsoft", |
30 |
| - ) |
31 |
| - return normalized, skip_sum |
32 |
| - |
33 |
| - |
34 |
| -_skip_rms_rule = pattern.RewriteRule(_skip_rms_norm_pattern, _skip_rms_normalization) |
35 |
| - |
36 |
| -skip_rms_normalization_rules = [_skip_rms_rule] |
37 |
| -skip_rms_normalization_ruleset = pattern.RewriteRuleSet(skip_rms_normalization_rules) |
38 |
| - |
39 |
| - |
40 |
| -def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type): |
41 |
| - skip_sum = op.Add(input, skip) |
42 |
| - normalized = op.LayerNormalization( |
43 |
| - skip_sum, |
44 |
| - gamma, |
45 |
| - beta, |
46 |
| - axis=-1, |
47 |
| - epsilon=epsilon, |
48 |
| - stash_type=stash_type, |
49 |
| - ) |
50 |
| - return normalized, skip_sum |
51 |
| - |
52 |
| - |
53 |
| -def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type): |
54 |
| - if stash_type.value != 1: # FLOAT type |
55 |
| - return None |
56 |
| - normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( |
57 |
| - input, |
58 |
| - skip, |
59 |
| - gamma, |
60 |
| - beta, |
61 |
| - epsilon=epsilon, |
62 |
| - _outputs=4, |
63 |
| - _domain="com.microsoft", |
64 |
| - ) |
65 |
| - return normalized, skip_sum |
66 |
| - |
67 |
| - |
68 |
| -# Fusion rule for Add + SkipLayerNormalization |
69 |
| -def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type): |
70 |
| - bias_sum = op.Add(input, bias) |
71 |
| - normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( |
72 |
| - bias_sum, |
73 |
| - skip, |
74 |
| - gamma, |
75 |
| - beta, |
76 |
| - epsilon=epsilon, |
77 |
| - _outputs=4, |
78 |
| - _domain="com.microsoft", |
79 |
| - ) |
80 |
| - return normalized, skip_sum |
81 |
| - |
82 |
| - |
83 |
| -def _skip_layer_normalization_add_bias( |
84 |
| - op, input, skip, gamma, beta, bias, epsilon, stash_type |
85 |
| -): |
86 |
| - normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( |
87 |
| - input, |
88 |
| - skip, |
89 |
| - gamma, |
90 |
| - beta, |
91 |
| - bias, |
92 |
| - epsilon=epsilon, |
93 |
| - _outputs=4, |
94 |
| - _domain="com.microsoft", |
95 |
| - ) |
96 |
| - return normalized, skip_sum |
97 |
| - |
98 |
| - |
99 |
| -_skip_layer_rule = pattern.RewriteRule( |
100 |
| - _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm" |
| 10 | +Dim = Union[int, ir.SymbolicDim] |
| 11 | + |
| 12 | +# Fusion rule for SkipRMSNormalization |
| 13 | + |
| 14 | + |
| 15 | +class SkipRmsNormFusion(pattern.RewriteRuleClassBase): |
| 16 | + def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False): |
| 17 | + """Fusion rule for SkipRMSNormalization.""" |
| 18 | + super().__init__(name=name) |
| 19 | + self._has_bias = has_bias |
| 20 | + self._bias_pre_add = bias_pre_add |
| 21 | + |
| 22 | + def pattern(self, op, input, skip, gamma, bias, epsilon, stash_type): |
| 23 | + if self._has_bias and self._bias_pre_add: |
| 24 | + input = op.Add(input, bias) |
| 25 | + skip_sum = op.Add(input, skip) |
| 26 | + if self._has_bias and not self._bias_pre_add: |
| 27 | + skip_sum = op.Add(skip_sum, bias) |
| 28 | + # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. |
| 29 | + # No need to use com.microsoft domain here; but this is a custom op in ORT. |
| 30 | + normalized = op.SimplifiedLayerNormalization( |
| 31 | + skip_sum, |
| 32 | + gamma, |
| 33 | + axis=-1, |
| 34 | + epsilon=epsilon, |
| 35 | + stash_type=stash_type, |
| 36 | + ) |
| 37 | + return normalized, skip_sum |
| 38 | + |
| 39 | + def check(self, op, input, skip, gamma, bias, epsilon, stash_type) -> pattern.MatchResult: # type: ignore[name-defined] |
| 40 | + """Check if the pattern matches conditions for use of SkipSimplifiedLayerNormalization op.""" |
| 41 | + check_result = pattern.MatchResult() |
| 42 | + bindings: dict[str, Dim] = {} |
| 43 | + |
| 44 | + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: |
| 45 | + return not _fusion_utils._check_shape(bindings, val, dims) |
| 46 | + |
| 47 | + if no_match(input, ["B", "S", "D"]): |
| 48 | + return check_result.fail( |
| 49 | + f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']", |
| 50 | + input, |
| 51 | + ) |
| 52 | + if no_match(skip, ["B", "S", "D"]): |
| 53 | + return check_result.fail( |
| 54 | + f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']", |
| 55 | + skip, |
| 56 | + ) |
| 57 | + if no_match(gamma, ["D"]): |
| 58 | + return check_result.fail( |
| 59 | + f"Shape mismatch: {gamma} does not match expected dimensions ['D']", |
| 60 | + gamma, |
| 61 | + ) |
| 62 | + if self._has_bias: |
| 63 | + if no_match(bias, ["D"]): |
| 64 | + return check_result.fail( |
| 65 | + f"Shape mismatch: {bias} does not match expected dimensions ['D']", |
| 66 | + bias, |
| 67 | + ) |
| 68 | + |
| 69 | + return check_result |
| 70 | + |
| 71 | + def rewrite(self, op, input, skip, gamma, bias, epsilon, stash_type): |
| 72 | + if self._has_bias: |
| 73 | + normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( |
| 74 | + input, |
| 75 | + skip, |
| 76 | + gamma, |
| 77 | + bias, |
| 78 | + epsilon=epsilon, |
| 79 | + _outputs=4, |
| 80 | + _domain="com.microsoft", |
| 81 | + ) |
| 82 | + else: |
| 83 | + normalized, _mean, _inv_std_var, skip_sum = op.SkipSimplifiedLayerNormalization( |
| 84 | + input, |
| 85 | + skip, |
| 86 | + gamma, |
| 87 | + epsilon=epsilon, |
| 88 | + _outputs=4, |
| 89 | + _domain="com.microsoft", |
| 90 | + ) |
| 91 | + return normalized, skip_sum |
| 92 | + |
| 93 | + |
| 94 | +_skip_rms_add_bias_rule = SkipRmsNormFusion.rule( |
| 95 | + "SkipRmsNormBias", has_bias=True, bias_pre_add=False |
101 | 96 | )
|
102 |
| -_skip_layer_add_bias_rule = pattern.RewriteRule( |
103 |
| - _skip_layer_norm_add_bias_pattern, |
104 |
| - _skip_layer_normalization_add_bias, |
105 |
| - name="SkipLayerNormAddBias", |
| 97 | +_skip_rms_pre_add_bias_rule = SkipRmsNormFusion.rule( |
| 98 | + "SkipRmsNormPreBias", has_bias=True, bias_pre_add=True |
106 | 99 | )
|
| 100 | +_skip_rms_rule = SkipRmsNormFusion.rule("SkipRmsNorm", has_bias=False) |
107 | 101 |
|
| 102 | +skip_rms_normalization_ruleset = pattern.RewriteRuleSet( |
| 103 | + [_skip_rms_pre_add_bias_rule, _skip_rms_add_bias_rule, _skip_rms_rule] |
| 104 | +) |
| 105 | +fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset) |
108 | 106 |
|
109 |
| -skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule] |
110 |
| -skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules) |
111 | 107 |
|
| 108 | +# Fusion rule for SkipLayerNormalization |
| 109 | +class SkipLayerNormFusion(pattern.RewriteRuleClassBase): |
| 110 | + def __init__(self, name: str, has_bias: bool = False, bias_pre_add: bool = False): |
| 111 | + """Fusion rule for SkipLayerNormalization.""" |
| 112 | + super().__init__(name=name) |
| 113 | + self._has_bias = has_bias |
| 114 | + self._bias_pre_add = bias_pre_add |
| 115 | + |
| 116 | + def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type): |
| 117 | + if self._has_bias and self._bias_pre_add: |
| 118 | + input = op.Add(input, bias) |
| 119 | + skip_sum = op.Add(input, skip) |
| 120 | + if self._has_bias and not self._bias_pre_add: |
| 121 | + skip_sum = op.Add(skip_sum, bias) |
| 122 | + normalized = op.LayerNormalization( |
| 123 | + skip_sum, |
| 124 | + gamma, |
| 125 | + beta, |
| 126 | + axis=-1, |
| 127 | + epsilon=epsilon, |
| 128 | + stash_type=stash_type, |
| 129 | + ) |
| 130 | + return normalized, skip_sum |
| 131 | + |
| 132 | + def check( |
| 133 | + self, op, input, skip, gamma, beta, bias, epsilon, stash_type |
| 134 | + ) -> pattern.MatchResult: # type: ignore[name-defined] |
| 135 | + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" |
| 136 | + check_result = pattern.MatchResult() |
| 137 | + bindings: dict[str, Dim] = {} |
| 138 | + |
| 139 | + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: |
| 140 | + return not _fusion_utils._check_shape(bindings, val, dims) |
| 141 | + |
| 142 | + if no_match(input, ["B", "S", "D"]): |
| 143 | + return check_result.fail( |
| 144 | + f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']", |
| 145 | + input, |
| 146 | + ) |
| 147 | + if no_match(skip, ["B", "S", "D"]): |
| 148 | + return check_result.fail( |
| 149 | + f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']", |
| 150 | + skip, |
| 151 | + ) |
| 152 | + if no_match(gamma, ["D"]): |
| 153 | + return check_result.fail( |
| 154 | + f"Shape mismatch: {gamma} does not match expected dimensions ['D']", |
| 155 | + gamma, |
| 156 | + ) |
| 157 | + if no_match(beta, ["D"]): |
| 158 | + return check_result.fail( |
| 159 | + f"Shape mismatch: {beta} does not match expected dimensions ['D']", |
| 160 | + beta, |
| 161 | + ) |
| 162 | + if self._has_bias: |
| 163 | + if no_match(bias, ["D"]): |
| 164 | + return check_result.fail( |
| 165 | + f"Shape mismatch: {bias} does not match expected dimensions ['D']", |
| 166 | + bias, |
| 167 | + ) |
| 168 | + |
| 169 | + return check_result |
| 170 | + |
| 171 | + def rewrite(self, op, input, skip, gamma, beta, bias, epsilon, stash_type): |
| 172 | + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( |
| 173 | + input, |
| 174 | + skip, |
| 175 | + gamma, |
| 176 | + beta, |
| 177 | + bias, |
| 178 | + epsilon=epsilon, |
| 179 | + _outputs=4, |
| 180 | + _domain="com.microsoft", |
| 181 | + ) |
| 182 | + return normalized, skip_sum |
| 183 | + |
| 184 | + |
| 185 | +_skip_layer_add_bias_rule = SkipLayerNormFusion.rule( |
| 186 | + "SkipLayerNormBias", has_bias=True, bias_pre_add=False |
| 187 | +) |
| 188 | +_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule( |
| 189 | + "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True |
| 190 | +) |
| 191 | +_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False) |
112 | 192 |
|
113 |
| -fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset) |
| 193 | +skip_layer_normalization_ruleset = pattern.RewriteRuleSet( |
| 194 | + [_skip_layer_pre_add_bias_rule, _skip_layer_add_bias_rule, _skip_layer_rule] |
| 195 | +) |
114 | 196 |
|
115 | 197 |
|
116 | 198 | fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(
|
|
0 commit comments