Skip to content

Allow sdpa fusion to accept custom scale factor #2210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions onnxscript/rewriter/ort_fusions/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
self._use_mask = use_mask
self._pre_scale = pre_scale
self._use_mul = use_mul
self._scale: float | None = None

def pattern(
self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
Expand Down Expand Up @@ -57,34 +58,53 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,

if self._pre_scale:
# Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
# If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
# Calculate the scaling factor for query
if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
return check_result.fail(
"Query scale is not a scalar or does not match the expected scaling factor.",
"Query scale is not a scalar.",
query_scale,
)
if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
# Ensure the scaling factor for key is the same as for query
if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
return check_result.fail(
"Key scale is not a scalar or does not match the expected scaling factor.",
"Key scale is not a scalar.",
key_scale,
)
if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
return check_result.fail(
"Query and key scales are not equal.",
query_scale,
)
if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
self._scale = query_scale_value * query_scale_value
else:
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
self._scale = None
else:
# Check if qk_scale is a scalar == expected_scaling_factor)
if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
# If it is a scalar but != sqrt(expected_scaling_factor), a custom scale is being used
if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
return check_result.fail(
"QK scale is not a scalar or does not match the expected scaling factor.",
"QK scale is not a scalar.",
qk_scale,
)
if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
self._scale = qk_scale_value
else:
# Pass no scaling factor to SDPA, SDPA will use the default scaling factor
self._scale = None

# check ranks/shapes

return check_result

def rewrite(self, op, query, key_transposed, value, mask, **_):
    """Replace the matched attention subgraph with a single fused SDPA node.

    Builds the argument list from the captured pattern values. The mask is
    only forwarded when this rule variant matched a masked pattern
    (``self._use_mask``). ``self._scale`` was populated by ``check()``:
    ``None`` means the graph used the default scaling factor and SDPA should
    apply its own default; a float means a custom scale was detected and must
    be forwarded explicitly.
    """
    sdpa_args = [query, key_transposed, value]
    if self._use_mask:
        sdpa_args.append(mask)
    # scale=None lets the SDPA op fall back to its default scaling factor;
    # otherwise the custom scale computed during check() is passed through.
    return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")


# Rules for SDPA without mask
Expand Down
140 changes: 140 additions & 0 deletions onnxscript/rewriter/ort_fusions/sdpa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
CUSTOM_SCALE_FACTOR = 2.0


@script()
Expand Down Expand Up @@ -74,6 +75,65 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
return attn_output


@script()
def _custom_scale_pre_div_sdpa_script(query, key, value):
    # Unmasked SDPA test graph: query and transposed key are each divided by
    # a custom (non-default) constant BEFORE the attention matmul, so the
    # fusion must detect a pre-scale custom scaling factor.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    scaled_query = op.Div(query, divisor)
    scaled_key = op.Div(key_transposed, divisor)
    attn_score = op.MatMul(scaled_query, scaled_key)
    attn_weight = op.Softmax(attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value):
    # Unmasked SDPA test graph: query and transposed key are each multiplied
    # by a custom (non-default) constant BEFORE the attention matmul, so the
    # fusion must detect a pre-scale custom scaling factor.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    attn_weight = op.Softmax(attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
    # Like _custom_scale_pre_mul_sdpa_script, but query and key are scaled by
    # two DISTINCT Constant nodes (with the same value). Exercises the check()
    # path that compares the two scale scalars for equality rather than
    # requiring a single shared constant.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier_q)
    scaled_key = op.Mul(key_transposed, multiplier_k)
    attn_score = op.MatMul(scaled_query, scaled_key)
    attn_weight = op.Softmax(attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_post_div_sdpa_script(query, key, value):
    # Unmasked SDPA test graph: the QK matmul result is divided by a custom
    # (non-default) constant AFTER the matmul (post-scale), so the fusion must
    # detect a post-scale custom scaling factor.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    attn_weight = op.Softmax(scaled_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_post_mul_sdpa_script(query, key, value):
    # Unmasked SDPA test graph: the QK matmul result is multiplied by a custom
    # (non-default) constant AFTER the matmul (post-scale), so the fusion must
    # detect a post-scale custom scaling factor.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    attn_weight = op.Softmax(scaled_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _masked_pre_div_sdpa_script(query, key, value, mask):
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
Expand Down Expand Up @@ -124,6 +184,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
return attn_output


@script()
def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
    # Masked variant: custom pre-scale via Div on query and transposed key,
    # with an additive attention mask applied before the softmax.
    # NOTE(review): this redefines `_custom_scale_pre_div_sdpa_script` from
    # earlier in this file, shadowing the unmasked variant — both the
    # "custom_scale_pre_div" and "custom_scale_pre_div_masked" parameterized
    # cases end up using this masked graph. Consider a distinct
    # `..._masked_...` name (and updating the parameterized list).
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    scaled_query = op.Div(query, divisor)
    scaled_key = op.Div(key_transposed, divisor)
    attn_score = op.MatMul(scaled_query, scaled_key)
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
    # Masked variant: custom pre-scale via Mul on query and transposed key,
    # with an additive attention mask applied before the softmax.
    # NOTE(review): this redefines `_custom_scale_pre_mul_sdpa_script` from
    # earlier in this file, shadowing the unmasked variant; both parameterized
    # cases will exercise this masked graph. Consider a distinct name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_post_div_sdpa_script(query, key, value, mask):
    # Masked variant: custom post-scale via Div on the QK matmul result,
    # with an additive attention mask applied before the softmax.
    # NOTE(review): this redefines `_custom_scale_post_div_sdpa_script` from
    # earlier in this file, shadowing the unmasked variant; both parameterized
    # cases will exercise this masked graph. Consider a distinct name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


@script()
def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
    # Masked variant: custom post-scale via Mul on the QK matmul result,
    # with an additive attention mask applied before the softmax.
    # NOTE(review): this redefines `_custom_scale_post_mul_sdpa_script` from
    # earlier in this file, shadowing the unmasked variant; both parameterized
    # cases will exercise this masked graph. Consider a distinct name.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output


class SDPATestCase:
def __init__(self, script_func):
self.script_func = script_func
Expand Down Expand Up @@ -161,6 +271,18 @@ class TestSDPAFusion(unittest.TestCase):
("pre_mul", _masked_pre_mul_sdpa_script),
("post_div", _masked_post_div_sdpa_script),
("post_mul", _masked_post_mul_sdpa_script),
("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
("custom_scale_post_mul_masked", _custom_scale_post_mul_sdpa_script),
("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
(
"_custom_multi_scale_pre_mul_sdpa_script",
_custom_multi_scale_pre_mul_sdpa_script,
),
]
)
def test_sdpa_fusion(self, name, script_func):
Expand All @@ -178,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
op_types = [n.op_type for n in model.graph]
self.assertIn("SDPA", op_types)

# Ensure that the scale of the SDPA node is set correctly
sdpa_node = next(n for n in model.graph if n.op_type == "SDPA")
self.assertEqual(sdpa_node.op_type, "SDPA")

if "custom" in name:
self.assertIsNotNone(sdpa_node.attributes.get("scale"))
scale_factor = sdpa_node.attributes["scale"].value
self.assertIsNotNone(scale_factor)
if "pre" in name:
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR)
elif "post" in name:
self.assertEqual(scale_factor, CUSTOM_SCALE_FACTOR)
else:
# These tests are for the default scaling factors, no scale factor is passed to SDPA
# pattern rewriting check functions should be sufficient to check if expected value
# of scale_factor (is =default_scaling_factor)
self.assertIsNone(sdpa_node.attributes.get("scale"))

# new_outputs = ort_run("optimized", model, inputs)
# assert_allclose(new_outputs, original_outputs)

Expand Down
Loading