Skip to content

Commit 2a4ed0d

Browse files
add quality check
1 parent 5f386b4 commit 2a4ed0d

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,13 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
7474
key_scale,
7575
)
7676
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
77+
if _ir_utils.get_singleton_value(query_scale) != _ir_utils.get_singleton_value(
78+
key_scale
79+
):
80+
return check_result.fail(
81+
"Query and key scales are not equal.",
82+
query_scale,
83+
)
7784
self._custom_scale = True
7885
else:
7986
# Check if qk_scale is a scalar == expected_scaling_factor)

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,19 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
9898
return attn_output
9999

100100

101+
@script()
102+
def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
103+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
104+
multiplier_q = op.Constant(value_float=0.5)
105+
multiplier_k = op.Constant(value_float=0.5)
106+
scaled_query = op.Mul(query, multiplier_q)
107+
scaled_key = op.Mul(key_transposed, multiplier_k)
108+
attn_score = op.MatMul(scaled_query, scaled_key)
109+
attn_weight = op.Softmax(attn_score, axis=-1)
110+
attn_output = op.MatMul(attn_weight, value)
111+
return attn_output
112+
113+
101114
@script()
102115
def _custom_scale_post_div_sdpa_script(query, key, value):
103116
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -265,6 +278,7 @@ class TestSDPAFusion(unittest.TestCase):
265278
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
266279
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
267280
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
281+
(_custom_multi_scale_pre_mul_sdpa_script, _custom_multi_scale_pre_mul_sdpa_script),
268282
]
269283
)
270284
def test_sdpa_fusion(self, name, script_func):

0 commit comments

Comments
 (0)