diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py
index 6a26afa4c8..faa7b29b38 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa.py
@@ -13,6 +13,7 @@ def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool)
         self._use_mask = use_mask
         self._pre_scale = pre_scale
         self._use_mul = use_mul
+        # Custom scale factor detected by check(); None means "use SDPA's default".
+        self._scale: float | None = None

     def pattern(
         self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale
@@ -57,34 +58,63 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,

         if self._pre_scale:
             # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor)
+            # If they are scalars but != sqrt(expected_scaling_factor), a custom scale is being used.
             sqrt_scaling_factor = math.sqrt(expected_scaling_factor)
-            if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3):
+            # Calculate the scaling factor for query
+            if (query_scale_value := _ir_utils.get_singleton_value(query_scale)) is None:
                 return check_result.fail(
-                    "Query scale is not a scalar or does not match the expected scaling factor.",
+                    "Query scale is not a scalar.",
                     query_scale,
                 )
+            # Ensure the scaling factor for key is the same as for query
-            if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3):
+            if (key_scale_value := _ir_utils.get_singleton_value(key_scale)) is None:
                 return check_result.fail(
-                    "Key scale is not a scalar or does not match the expected scaling factor.",
+                    "Key scale is not a scalar.",
                     key_scale,
                 )
+            if not math.isclose(query_scale_value, key_scale_value, rel_tol=1e-3):
+                return check_result.fail(
+                    "Query and key scales are not equal.",
+                    query_scale,
+                )
+            if not math.isclose(query_scale_value, sqrt_scaling_factor, rel_tol=1e-3):
+                # Custom scale. SDPA's scale attribute multiplies Q*K^T, so a Mul
+                # pattern contributes value**2 while a Div pattern contributes 1/value**2.
+                if self._use_mul:
+                    self._scale = query_scale_value * query_scale_value
+                else:
+                    self._scale = 1.0 / (query_scale_value * query_scale_value)
+            else:
+                # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
+                self._scale = None
         else:
             # Check if qk_scale is a scalar == expected_scaling_factor)
-            if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3):
+            # If it is a scalar but != expected_scaling_factor, a custom scale is being used
+            if (qk_scale_value := _ir_utils.get_singleton_value(qk_scale)) is None:
                 return check_result.fail(
-                    "QK scale is not a scalar or does not match the expected scaling factor.",
+                    "QK scale is not a scalar.",
                     qk_scale,
                 )
+            if not math.isclose(qk_scale_value, expected_scaling_factor, rel_tol=1e-3):
+                # Mul scales the attention score by the value; Div scales by its reciprocal.
+                if self._use_mul:
+                    self._scale = qk_scale_value
+                else:
+                    self._scale = 1.0 / qk_scale_value
+            else:
+                # Pass no scaling factor to SDPA, SDPA will use the default scaling factor
+                self._scale = None

         # check ranks/shapes

         return check_result

     def rewrite(self, op, query, key_transposed, value, mask, **_):
+        sdpa_args = [query, key_transposed, value]
         if self._use_mask:
-            return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
-        else:
-            return op.SDPA(query, key_transposed, value, _domain="ai.onnxruntime.fusion")
+            sdpa_args.append(mask)
+        return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")


# Rules for SDPA without mask
diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py
index 1cd79e1c42..74c718147f 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa_test.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py
@@ -26,6 +26,7 @@
 MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
 SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
 SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
+CUSTOM_SCALE_FACTOR = 2.0


 @script()
@@ -74,6 +75,65 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
     return attn_output


+@script()
+def _custom_scale_pre_div_sdpa_script(query, key, value):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    scaled_query = op.Div(query, divisor)
+    scaled_key = op.Div(key_transposed, divisor)
+    attn_score = op.MatMul(scaled_query, scaled_key)
+    attn_weight = op.Softmax(attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _custom_scale_pre_mul_sdpa_script(query, key, value):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    scaled_query = op.Mul(query, multiplier)
+    scaled_key = op.Mul(key_transposed, multiplier)
+    attn_score = op.MatMul(scaled_query, scaled_key)
+    attn_weight = op.Softmax(attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    multiplier_q = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    multiplier_k = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    scaled_query = op.Mul(query, multiplier_q)
+    scaled_key = op.Mul(key_transposed, multiplier_k)
+    attn_score = op.MatMul(scaled_query, scaled_key)
+    attn_weight = op.Softmax(attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _custom_scale_post_div_sdpa_script(query, key, value):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_attn_score = op.Div(attn_score, divisor)
+    attn_weight = op.Softmax(scaled_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _custom_scale_post_mul_sdpa_script(query, key, value):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_attn_score = op.Mul(attn_score, multiplier)
+    attn_weight = op.Softmax(scaled_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
 @script()
 def _masked_pre_div_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
@@ -124,6 +184,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
     return attn_output


+@script()
+def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    scaled_query = op.Div(query, divisor)
+    scaled_key = op.Div(key_transposed, divisor)
+    attn_score = op.MatMul(scaled_query, scaled_key)
+    masked_attn_score = op.Add(attn_score, mask)
+    attn_weight = op.Softmax(masked_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    scaled_query = op.Mul(query, multiplier)
+    scaled_key = op.Mul(key_transposed, multiplier)
+    attn_score = op.MatMul(scaled_query, scaled_key)
+    masked_attn_score = op.Add(attn_score, mask)
+    attn_weight = op.Softmax(masked_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    divisor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_attn_score = op.Div(attn_score, divisor)
+    masked_attn_score = op.Add(scaled_attn_score, mask)
+    attn_weight = op.Softmax(masked_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
+@script()
+def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
+    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
+    multiplier = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
+    attn_score = op.MatMul(query, key_transposed)
+    scaled_attn_score = op.Mul(attn_score, multiplier)
+    masked_attn_score = op.Add(scaled_attn_score, mask)
+    attn_weight = op.Softmax(masked_attn_score, axis=-1)
+    attn_output = op.MatMul(attn_weight, value)
+    return attn_output
+
+
 class SDPATestCase:
     def __init__(self, script_func):
         self.script_func = script_func
@@ -161,6 +271,18 @@ class TestSDPAFusion(unittest.TestCase):
             ("pre_mul", _masked_pre_mul_sdpa_script),
             ("post_div", _masked_post_div_sdpa_script),
             ("post_mul", _masked_post_mul_sdpa_script),
+            ("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
+            ("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
+            ("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
+            ("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
+            ("custom_scale_post_mul_masked", _masked_custom_scale_post_mul_sdpa_script),
+            ("custom_scale_post_div_masked", _masked_custom_scale_post_div_sdpa_script),
+            ("custom_scale_pre_mul_masked", _masked_custom_scale_pre_mul_sdpa_script),
+            ("custom_scale_pre_div_masked", _masked_custom_scale_pre_div_sdpa_script),
+            (
+                "custom_multi_scale_pre_mul",
+                _custom_multi_scale_pre_mul_sdpa_script,
+            ),
         ]
     )
     def test_sdpa_fusion(self, name, script_func):
@@ -178,6 +300,28 @@ def test_sdpa_fusion(self, name, script_func):

         op_types = [n.op_type for n in model.graph]
         self.assertIn("SDPA", op_types)
+
+        # Ensure that the scale of the SDPA node is set correctly
+        sdpa_node = next(n for n in model.graph if n.op_type == "SDPA")
+        self.assertEqual(sdpa_node.op_type, "SDPA")
+
+        if "custom" in name:
+            self.assertIsNotNone(sdpa_node.attributes.get("scale"))
+            scale_factor = sdpa_node.attributes["scale"].value
+            self.assertIsNotNone(scale_factor)
+            # Mul patterns scale Q*K^T by the constant itself, while Div patterns
+            # scale by its reciprocal.
+            base = CUSTOM_SCALE_FACTOR if "mul" in name else 1.0 / CUSTOM_SCALE_FACTOR
+            if "pre" in name:
+                # Query and key are each scaled, so the factors compound.
+                self.assertEqual(scale_factor, base * base)
+            elif "post" in name:
+                self.assertEqual(scale_factor, base)
+        else:
+            # These tests are for the default scaling factors, no scale factor is passed to SDPA
+            # pattern rewriting check functions should be sufficient to check if expected value
+            # of scale_factor (is =default_scaling_factor)
+            self.assertIsNone(sdpa_node.attributes.get("scale"))

         # new_outputs = ort_run("optimized", model, inputs)
         # assert_allclose(new_outputs, original_outputs)