Skip to content

Commit 9c87a4c

Browse files
rebase modifs
1 parent 2b8e72e commit 9c87a4c

File tree

2 files changed

+20
-31
lines changed

2 files changed

+20
-31
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
broadcast_to_matmul,
1919
cast_constant_of_shape,
2020
collapse_slices,
21-
erfgelu,
2221
gemm_to_matmul_add,
2322
llama_rule_sets,
2423
no_op,
@@ -29,7 +28,6 @@
2928
_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
3029
*no_op.rules.rules, # TODO: merge this rule into constant folding?
3130
*broadcast_to_matmul.rules.rules,
32-
*erfgelu.rules.rules, # type: ignore[has-type]
3331
gemm_to_matmul_add.rule, # type: ignore[has-type]
3432
*cast_constant_of_shape.rules.rules,
3533
*collapse_slices.rules.rules,

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
from onnxscript.rewriter import _fusion_utils, pattern
5+
import math
6+
7+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
68

79

810
class SDPA(pattern.RewriteRuleClassBase):
@@ -166,34 +168,23 @@ def rewrite(
166168
return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
167169

168170

169-
# Rules for SDPA without mask
170-
unmasked_pre_div_sdpa_rule = SDPA.rule(
171-
"unmasked_pre_div_sdpa", use_mask=False, pre_scale=True, use_mul=False
172-
)
173-
unmasked_pre_mul_sdpa_rule = SDPA.rule(
174-
"unmasked_pre_mul_sdpa", use_mask=False, pre_scale=True, use_mul=True
175-
)
176-
unmasked_post_div_sdpa_rule = SDPA.rule(
177-
"unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False
178-
)
179-
unmasked_post_mul_sdpa_rule = SDPA.rule(
180-
"unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True
181-
)
182-
183-
# Rules for SDPA with mask
184-
masked_pre_div_sdpa_rule = SDPA.rule(
185-
"masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False
186-
)
187-
masked_pre_mul_sdpa_rule = SDPA.rule(
188-
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True
189-
)
190-
masked_post_div_sdpa_rule = SDPA.rule(
191-
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
192-
)
193-
masked_post_mul_sdpa_rule = SDPA.rule(
194-
"masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True
195-
)
196-
171+
parameter_combinations = [
172+
{
173+
"name": f"sdpa_{'masked_' if use_mask else 'unmasked_'}{'pre_' if pre_scale else 'post_'}{'only_q_' if pre_scale_q else ''}{'mul' if use_mul else 'div'}{'_3d' if has_3d_inputs else ''}",
174+
"use_mask": use_mask,
175+
"pre_scale": pre_scale,
176+
"pre_scale_q": pre_scale_q,
177+
"use_mul": use_mul,
178+
"has_3d_inputs": has_3d_inputs,
179+
}
180+
for use_mask in [False, True]
181+
for pre_scale in [False, True]
182+
for pre_scale_q in [False, True]
183+
for use_mul in [False, True]
184+
for has_3d_inputs in [False, True]
185+
]
186+
187+
# Dynamically create the rules
197188
sdpa_rules = pattern.RewriteRuleSet(
198189
[
199190
SDPA.rule(

0 commit comments

Comments
 (0)