Skip to content

Commit e62345a

Browse files
Add scale correctness test for SDPA fusion
1 parent a23486c commit e62345a

File tree

2 files changed

+34
-12
lines changed

2 files changed

+34
-12
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6666
"Query scale is not a scalar.",
6767
query_scale,
6868
)
69-
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
69+
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
7070
self._scale = query_scale_value * query_scale_value
7171
else:
7272
self._scale = expected_scaling_factor

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
2727
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
2828
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
29+
CUSTOM_SCALE_FACTOR = 2.0
2930

3031

3132
@script()
@@ -77,7 +78,7 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7778
@script()
7879
def _custom_scale_pre_div_sdpa_script(query, key, value):
7980
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
80-
divisor = op.Constant(value_float=2.0)
81+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
8182
scaled_query = op.Div(query, divisor)
8283
scaled_key = op.Div(key_transposed, divisor)
8384
attn_score = op.MatMul(scaled_query, scaled_key)
@@ -89,7 +90,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
8990
@script()
9091
def _custom_scale_pre_mul_sdpa_script(query, key, value):
9192
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
92-
multiplier = op.Constant(value_float=0.5)
93+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
9394
scaled_query = op.Mul(query, multiplier)
9495
scaled_key = op.Mul(key_transposed, multiplier)
9596
attn_score = op.MatMul(scaled_query, scaled_key)
@@ -101,8 +102,8 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
101102
@script()
102103
def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
103104
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
104-
multiplier_q = op.Constant(value_float=0.5)
105-
multiplier_k = op.Constant(value_float=0.5)
105+
multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
106+
multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
106107
scaled_query = op.Mul(query, multiplier_q)
107108
scaled_key = op.Mul(key_transposed, multiplier_k)
108109
attn_score = op.MatMul(scaled_query, scaled_key)
@@ -114,7 +115,7 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
114115
@script()
115116
def _custom_scale_post_div_sdpa_script(query, key, value):
116117
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
117-
divisor = op.Constant(value_float=0.1)
118+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
118119
attn_score = op.MatMul(query, key_transposed)
119120
scaled_attn_score = op.Div(attn_score, divisor)
120121
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
@@ -125,7 +126,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
125126
@script()
126127
def _custom_scale_post_mul_sdpa_script(query, key, value):
127128
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
128-
multiplier = op.Constant(value_float=0.125)
129+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
129130
attn_score = op.MatMul(query, key_transposed)
130131
scaled_attn_score = op.Mul(attn_score, multiplier)
131132
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
@@ -186,7 +187,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
186187
@script()
187188
def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
188189
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
189-
divisor = op.Constant(value_float=2.0)
190+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
190191
scaled_query = op.Div(query, divisor)
191192
scaled_key = op.Div(key_transposed, divisor)
192193
attn_score = op.MatMul(scaled_query, scaled_key)
@@ -199,7 +200,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
199200
@script()
200201
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
201202
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
202-
multiplier = op.Constant(value_float=0.5)
203+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
203204
scaled_query = op.Mul(query, multiplier)
204205
scaled_key = op.Mul(key_transposed, multiplier)
205206
attn_score = op.MatMul(scaled_query, scaled_key)
@@ -212,7 +213,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
212213
@script()
213214
def _custom_scale_post_div_sdpa_script(query, key, value, mask):
214215
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
215-
divisor = op.Constant(value_float=0.1)
216+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
216217
attn_score = op.MatMul(query, key_transposed)
217218
scaled_attn_score = op.Div(attn_score, divisor)
218219
masked_attn_score = op.Add(scaled_attn_score, mask)
@@ -224,7 +225,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
224225
@script()
225226
def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
226227
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
227-
multiplier = op.Constant(value_float=0.125)
228+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
228229
attn_score = op.MatMul(query, key_transposed)
229230
scaled_attn_score = op.Mul(attn_score, multiplier)
230231
masked_attn_score = op.Add(scaled_attn_score, mask)
@@ -278,7 +279,10 @@ class TestSDPAFusion(unittest.TestCase):
278279
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
279280
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
280281
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
281-
("_custom_multi_scale_pre_mul_sdpa_script", _custom_multi_scale_pre_mul_sdpa_script),
282+
(
283+
"_custom_multi_scale_pre_mul_sdpa_script",
284+
_custom_multi_scale_pre_mul_sdpa_script,
285+
),
282286
]
283287
)
284288
def test_sdpa_fusion(self, name, script_func):
@@ -296,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
296300
op_types = [n.op_type for n in model.graph]
297301
self.assertIn("SDPA", op_types)
298302

303+
# Ensure that the scale of the SDPA node is set correctly
304+
sdpa_node = next(n for n in model.graph if n.op_type == "SDPA")
305+
self.assertEqual(sdpa_node.op_type, "SDPA")
306+
self.assertIsNotNone(sdpa_node.attributes.get("scale"))
307+
308+
scale_factor = sdpa_node.attributes["scale"].value
309+
self.assertIsNotNone(scale_factor)
310+
if "custom" in name:
311+
if "pre" in name:
312+
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR)
313+
elif "post" in name:
314+
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR)
315+
else:
316+
if "div" in name:
317+
self.assertEqual(scale_factor, SCALE_FACTOR)
318+
elif "mul" in name:
319+
self.assertEqual(scale_factor, MUL_SCALE_FACTOR)
320+
299321
# new_outputs = ort_run("optimized", model, inputs)
300322
# assert_allclose(new_outputs, original_outputs)
301323

0 commit comments

Comments
 (0)