Skip to content

Commit fc2e5da

Browse files
Allow sdpa fusion to accept custom scale factor (#2210)
1 parent f5327f8 commit fc2e5da

File tree

2 files changed

+169
-9
lines changed

2 files changed

+169
-9
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
1313
self._use_mask = use_mask
1414
self._pre_scale = pre_scale
1515
self._use_mul = use_mul
16+
self._scale: float | None = None
1617

1718
def pattern(
1819
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
@@ -57,34 +58,53 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
5758

5859
if self._pre_scale:
5960
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
61+
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
6062
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
61-
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
63+
# Calculate the scaling factor for query
64+
if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
6265
return check_result.fail(
63-
"Query scale is not a scalar or does not match the expected scaling factor.",
66+
"Query scale is not a scalar.",
6467
query_scale,
6568
)
66-
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
69+
# Ensure the scaling factor for key is the same as for query
70+
if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
6771
return check_result.fail(
68-
"Key scale is not a scalar or does not match the expected scaling factor.",
72+
"Key scale is not a scalar.",
6973
key_scale,
7074
)
75+
if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
76+
return check_result.fail(
77+
"Query and key scales are not equal.",
78+
query_scale,
79+
)
80+
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
81+
self._scale = query_scale_value * query_scale_value
82+
else:
83+
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
84+
self._scale = None
7185
else:
7286
# Check if qk_scale is a scalar == expected_scaling_factor)
73-
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
87+
# If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
88+
if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
7489
return check_result.fail(
75-
"QK scale is not a scalar or does not match the expected scaling factor.",
90+
"QK scale is not a scalar.",
7691
qk_scale,
7792
)
93+
if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
94+
self._scale = qk_scale_value
95+
else:
96+
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
97+
self._scale = None
7898

7999
# check ranks/shapes
80100

81101
return check_result
82102

83103
def rewrite(self, op, query, key_transposed, value, mask, **_):
104+
sdpa_args = [query, key_transposed, value]
84105
if self._use_mask:
85-
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
86-
else:
87-
return op.SDPA(query, key_transposed, value, _domain="ai.onnxruntime.fusion")
106+
sdpa_args.append(mask)
107+
return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
88108

89109

90110
# Rules for SDPA without mask

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
2727
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
2828
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
29+
CUSTOM_SCALE_FACTOR = 2.0
2930

3031

3132
@script()
@@ -74,6 +75,65 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7475
return attn_output
7576

7677

78+
@script()
79+
def _custom_scale_pre_div_sdpa_script(query, key, value):
80+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
81+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
82+
scaled_query = op.Div(query, divisor)
83+
scaled_key = op.Div(key_transposed, divisor)
84+
attn_score = op.MatMul(scaled_query, scaled_key)
85+
attn_weight = op.Softmax(attn_score, axis=-1)
86+
attn_output = op.MatMul(attn_weight, value)
87+
return attn_output
88+
89+
90+
@script()
91+
def _custom_scale_pre_mul_sdpa_script(query, key, value):
92+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
93+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
94+
scaled_query = op.Mul(query, multiplier)
95+
scaled_key = op.Mul(key_transposed, multiplier)
96+
attn_score = op.MatMul(scaled_query, scaled_key)
97+
attn_weight = op.Softmax(attn_score, axis=-1)
98+
attn_output = op.MatMul(attn_weight, value)
99+
return attn_output
100+
101+
102+
@script()
103+
def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
104+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
105+
multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
106+
multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
107+
scaled_query = op.Mul(query, multiplier_q)
108+
scaled_key = op.Mul(key_transposed, multiplier_k)
109+
attn_score = op.MatMul(scaled_query, scaled_key)
110+
attn_weight = op.Softmax(attn_score, axis=-1)
111+
attn_output = op.MatMul(attn_weight, value)
112+
return attn_output
113+
114+
115+
@script()
116+
def _custom_scale_post_div_sdpa_script(query, key, value):
117+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
118+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
119+
attn_score = op.MatMul(query, key_transposed)
120+
scaled_attn_score = op.Div(attn_score, divisor)
121+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
122+
attn_output = op.MatMul(attn_weight, value)
123+
return attn_output
124+
125+
126+
@script()
127+
def _custom_scale_post_mul_sdpa_script(query, key, value):
128+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
129+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
130+
attn_score = op.MatMul(query, key_transposed)
131+
scaled_attn_score = op.Mul(attn_score, multiplier)
132+
attn_weight = op.Softmax(scaled_attn_score, axis=-1)
133+
attn_output = op.MatMul(attn_weight, value)
134+
return attn_output
135+
136+
77137
@script()
78138
def _masked_pre_div_sdpa_script(query, key, value, mask):
79139
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -124,6 +184,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124184
return attn_output
125185

126186

187+
@script()
188+
def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
189+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
190+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
191+
scaled_query = op.Div(query, divisor)
192+
scaled_key = op.Div(key_transposed, divisor)
193+
attn_score = op.MatMul(scaled_query, scaled_key)
194+
masked_attn_score = op.Add(attn_score, mask)
195+
attn_weight = op.Softmax(masked_attn_score, axis=-1)
196+
attn_output = op.MatMul(attn_weight, value)
197+
return attn_output
198+
199+
200+
@script()
201+
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
202+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
203+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
204+
scaled_query = op.Mul(query, multiplier)
205+
scaled_key = op.Mul(key_transposed, multiplier)
206+
attn_score = op.MatMul(scaled_query, scaled_key)
207+
masked_attn_score = op.Add(attn_score, mask)
208+
attn_weight = op.Softmax(masked_attn_score, axis=-1)
209+
attn_output = op.MatMul(attn_weight, value)
210+
return attn_output
211+
212+
213+
@script()
214+
def _custom_scale_post_div_sdpa_script(query, key, value, mask):
215+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
216+
divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
217+
attn_score = op.MatMul(query, key_transposed)
218+
scaled_attn_score = op.Div(attn_score, divisor)
219+
masked_attn_score = op.Add(scaled_attn_score, mask)
220+
attn_weight = op.Softmax(masked_attn_score, axis=-1)
221+
attn_output = op.MatMul(attn_weight, value)
222+
return attn_output
223+
224+
225+
@script()
226+
def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
227+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
228+
multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
229+
attn_score = op.MatMul(query, key_transposed)
230+
scaled_attn_score = op.Mul(attn_score, multiplier)
231+
masked_attn_score = op.Add(scaled_attn_score, mask)
232+
attn_weight = op.Softmax(masked_attn_score, axis=-1)
233+
attn_output = op.MatMul(attn_weight, value)
234+
return attn_output
235+
236+
127237
class SDPATestCase:
128238
def __init__(self, script_func):
129239
self.script_func = script_func
@@ -161,6 +271,18 @@ class TestSDPAFusion(unittest.TestCase):
161271
("pre_mul", _masked_pre_mul_sdpa_script),
162272
("post_div", _masked_post_div_sdpa_script),
163273
("post_mul", _masked_post_mul_sdpa_script),
274+
("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
275+
("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
276+
("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
277+
("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
278+
("custom_scale_post_mul_masked", _custom_scale_post_mul_sdpa_script),
279+
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
280+
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
281+
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
282+
(
283+
"_custom_multi_scale_pre_mul_sdpa_script",
284+
_custom_multi_scale_pre_mul_sdpa_script,
285+
),
164286
]
165287
)
166288
def test_sdpa_fusion(self, name, script_func):
@@ -178,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
178300
op_types = [n.op_type for n in model.graph]
179301
self.assertIn("SDPA", op_types)
180302

303+
# Ensure that the scale of the SDPA node is set correctly
304+
sdpa_node = next(n for n in model.graph if n.op_type == "SDPA")
305+
self.assertEqual(sdpa_node.op_type, "SDPA")
306+
307+
if "custom" in name:
308+
self.assertIsNotNone(sdpa_node.attributes.get("scale"))
309+
scale_factor = sdpa_node.attributes["scale"].value
310+
self.assertIsNotNone(scale_factor)
311+
if "pre" in name:
312+
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR)
313+
elif "post" in name:
314+
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR)
315+
else:
316+
# These tests are for the default scaling factors, no scale factor is passed to SDPA
317+
# pattern rewriting check functions should be sufficient to check if expected value
318+
# of scale_factor (is =default_scaling_factor)
319+
self.assertIsNone(sdpa_node.attributes.get("scale"))
320+
181321
# new_outputs = ort_run("optimized", model, inputs)
182322
# assert_allclose(new_outputs, original_outputs)
183323

0 commit comments

Comments
 (0)