@@ -129,38 +129,70 @@ def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
        )
        return normalized, skip_sum

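+    # Condition method: the rewrite below is applied only when every shape
+    # check on the matched values succeeds.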
+    def check(
+        self, op, input, skip, gamma, beta, bias, epsilon, stash_type
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of SkipLayerNormalization op."""
+        check_result = pattern.MatchResult()
+        bindings: dict[str, Dim] = {}

-def _skip_layer_normalization_add_bias(
-    op, input, skip, gamma, beta, bias, epsilon, stash_type
-):
-    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
-        input,
-        skip,
-        gamma,
-        beta,
-        bias,
-        epsilon=epsilon,
-        _outputs=4,
-        _domain="com.microsoft",
-    )
-    return normalized, skip_sum
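+        # no_match records symbolic dims in `bindings` as it checks each value,
+        # so all values below must agree on the same B, S, and D sizes.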
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(bindings, val, dims)
+
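+        # input and skip must be rank-3 [batch, sequence, hidden]; gamma and
+        # beta (and bias, when present) must be 1-D over the hidden dimension.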
+        if no_match(input, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {input} does not match expected dimensions ['B', 'S', 'D']",
+                input,
+            )
+        if no_match(skip, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {skip} does not match expected dimensions ['B', 'S', 'D']",
+                skip,
+            )
+        if no_match(gamma, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {gamma} does not match expected dimensions ['D']",
+                gamma,
+            )
+        if no_match(beta, ["D"]):
+            return check_result.fail(
+                f"Shape mismatch: {beta} does not match expected dimensions ['D']",
+                beta,
+            )
+        if self._has_bias:
+            if no_match(bias, ["D"]):
+                return check_result.fail(
+                    f"Shape mismatch: {bias} does not match expected dimensions ['D']",
+                    bias,
+                )
+
+        return check_result
+
+    def rewrite(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
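+        # Emit ORT's fused contrib op. It produces four outputs; only the
+        # normalized result and the input+skip sum are returned to the graph,
+        # while the mean and inverse-std outputs are discarded.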
+        normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+            input,
+            skip,
+            gamma,
+            beta,
+            bias,
+            epsilon=epsilon,
+            _outputs=4,
+            _domain="com.microsoft",
+        )
+        return normalized, skip_sum


_skip_layer_add_bias_rule = SkipLayerNormFusion.rule(
    "SkipLayerNormBias", has_bias=True, bias_pre_add=False
)
-_skip_layer_add_bias_rule = pattern.RewriteRule(
-    _skip_layer_norm_add_bias_pattern,
-    _skip_layer_normalization_add_bias,
-    name="SkipLayerNormAddBias",
+_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule(
+    "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True
)
+_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False)

-
-skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule]
-skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
-
-
-fuse_skip_rms_normalization = _fusion_utils.apply_fusion_rules(skip_rms_normalization_ruleset)
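+# The more specific bias-handling variants are listed ahead of the plain
+# SkipLayerNorm pattern in the ruleset.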
+skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
+    [_skip_layer_pre_add_bias_rule, _skip_layer_add_bias_rule, _skip_layer_rule]
+)


fuse_skip_layer_normalization = _fusion_utils.apply_fusion_rules(