Skip to content

Commit a23486c

Browse files
typos
1 parent d1507f3 commit a23486c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
197197

198198

199199
@script()
200-
def _custom_scale_mul_sdpa_script(query, key, value, mask):
200+
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
201201
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
202202
multiplier = op.Constant(value_float=0.5)
203203
scaled_query = op.Mul(query, multiplier)
@@ -278,7 +278,7 @@ class TestSDPAFusion(unittest.TestCase):
278278
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
279279
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
280280
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
281-
(_custom_multi_scale_pre_mul_sdpa_script, _custom_multi_scale_pre_mul_sdpa_script),
281+
("_custom_multi_scale_pre_mul_sdpa_script", _custom_multi_scale_pre_mul_sdpa_script),
282282
]
283283
)
284284
def test_sdpa_fusion(self, name, script_func):

0 commit comments

Comments
 (0)