|
2 | 2 | # Licensed under the MIT License.
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from onnxscript.rewriter import _fusion_utils, pattern |
| 5 | +import math |
| 6 | + |
| 7 | +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern |
6 | 8 |
|
7 | 9 |
|
8 | 10 | class SDPA(pattern.RewriteRuleClassBase):
|
@@ -166,34 +168,23 @@ def rewrite(
|
166 | 168 | return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
|
167 | 169 |
|
168 | 170 |
|
169 |
| -# Rules for SDPA without mask |
170 |
| -unmasked_pre_div_sdpa_rule = SDPA.rule( |
171 |
| - "unmasked_pre_div_sdpa", use_mask=False, pre_scale=True, use_mul=False |
172 |
| -) |
173 |
| -unmasked_pre_mul_sdpa_rule = SDPA.rule( |
174 |
| - "unmasked_pre_mul_sdpa", use_mask=False, pre_scale=True, use_mul=True |
175 |
| -) |
176 |
| -unmasked_post_div_sdpa_rule = SDPA.rule( |
177 |
| - "unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False |
178 |
| -) |
179 |
| -unmasked_post_mul_sdpa_rule = SDPA.rule( |
180 |
| - "unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True |
181 |
| -) |
182 |
| - |
183 |
| -# Rules for SDPA with mask |
184 |
| -masked_pre_div_sdpa_rule = SDPA.rule( |
185 |
| - "masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False |
186 |
| -) |
187 |
| -masked_pre_mul_sdpa_rule = SDPA.rule( |
188 |
| - "masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True |
189 |
| -) |
190 |
| -masked_post_div_sdpa_rule = SDPA.rule( |
191 |
| - "masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False |
192 |
| -) |
193 |
| -masked_post_mul_sdpa_rule = SDPA.rule( |
194 |
| - "masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True |
195 |
| -) |
196 |
| - |
| 171 | +parameter_combinations = [ |
| 172 | + { |
| 173 | + "name": f"sdpa_{'masked_' if use_mask else 'unmasked_'}{'pre_' if pre_scale else 'post_'}{'only_q_' if pre_scale_q else ''}{'mul' if use_mul else 'div'}{'_3d' if has_3d_inputs else ''}", |
| 174 | + "use_mask": use_mask, |
| 175 | + "pre_scale": pre_scale, |
| 176 | + "pre_scale_q": pre_scale_q, |
| 177 | + "use_mul": use_mul, |
| 178 | + "has_3d_inputs": has_3d_inputs, |
| 179 | + } |
| 180 | + for use_mask in [False, True] |
| 181 | + for pre_scale in [False, True] |
| 182 | + for pre_scale_q in [False, True] |
| 183 | + for use_mul in [False, True] |
| 184 | + for has_3d_inputs in [False, True] |
| 185 | +] |
| 186 | + |
| 187 | +# Dynamically create the rules |
197 | 188 | sdpa_rules = pattern.RewriteRuleSet(
|
198 | 189 | [
|
199 | 190 | SDPA.rule(
|
|
0 commit comments