Skip to content

Commit ca15d07

Browse files
typo: fix
1 parent 8cec4c8 commit ca15d07

File tree

1 file changed

+2
-2
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+2
-2
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
8484
"unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=False
8585
)
8686
unmasked_post_mul_sdpa_rule = SDPA.rule(
87-
"unmasked_post_div_sdpa", use_mask=False, pre_scale=False, use_mul=True
87+
"unmasked_post_mul_sdpa", use_mask=False, pre_scale=False, use_mul=True
8888
)
8989

9090
# Rules for SDPA with mask
@@ -98,7 +98,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):
9898
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=False
9999
)
100100
masked_post_mul_sdpa_rule = SDPA.rule(
101-
"masked_post_div_sdpa", use_mask=True, pre_scale=False, use_mul=True
101+
"masked_post_mul_sdpa", use_mask=True, pre_scale=False, use_mul=True
102102
)
103103

104104
sdpa_rules = pattern.RewriteRuleSet(

0 commit comments

Comments
 (0)