Skip to content

Extend sdpa tests #2118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def rewrite(self, op, query, key_transposed, value, mask, **_):


masked_pre_div_sdpa_rule = SDPA.rule(
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=False
"masked_pre_div_sdpa", use_mask=True, pre_scale=True, use_mul=False
)
masked_pre_mul_sdpa_rule = SDPA.rule(
"masked_pre_mul_sdpa", use_mask=True, pre_scale=True, use_mul=True
Expand Down
66 changes: 58 additions & 8 deletions onnxscript/rewriter/ort_fusions/sdpa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest

import numpy
from parameterized import parameterized

import onnxscript.ir as ir
import onnxscript.optimizer
Expand All @@ -22,7 +23,9 @@
S = 8 # sequence length
H = 128 # head size
SCALE_FACTOR = math.sqrt(H)
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)


@script()
Expand All @@ -38,16 +41,55 @@
return attn_output


class _MaskedPreDivSDPATestCase:
@script()
def _masked_pre_mul_sdpa_script(query, key, value, mask):
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
scaled_query = op.Mul(query, multiplier)
scaled_key = op.Mul(key_transposed, multiplier)
attn_score = op.MatMul(scaled_query, scaled_key)
masked_attn_score = op.Add(attn_score, mask)
attn_weight = op.Softmax(masked_attn_score, axis=-1)
attn_output = op.MatMul(attn_weight, value)
return attn_output

Check warning on line 54 in onnxscript/rewriter/ort_fusions/sdpa_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/sdpa_test.py#L46-L54

Added lines #L46 - L54 were not covered by tests


@script()
def _masked_post_div_sdpa_script(query, key, value, mask):
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
divisor = op.Constant(value_float=SCALE_FACTOR)
attn_score = op.MatMul(query, key_transposed)
scaled_attn_score = op.Div(attn_score, divisor)
masked_attn_score = op.Add(scaled_attn_score, mask)
attn_weight = op.Softmax(masked_attn_score, axis=-1)
attn_output = op.MatMul(attn_weight, value)
return attn_output

Check warning on line 66 in onnxscript/rewriter/ort_fusions/sdpa_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/sdpa_test.py#L59-L66

Added lines #L59 - L66 were not covered by tests


@script()
def _masked_post_mul_sdpa_script(query, key, value, mask):
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
attn_score = op.MatMul(query, key_transposed)
scaled_attn_score = op.Mul(attn_score, multiplier)
masked_attn_score = op.Add(scaled_attn_score, mask)
attn_weight = op.Softmax(masked_attn_score, axis=-1)
attn_output = op.MatMul(attn_weight, value)
return attn_output

Check warning on line 78 in onnxscript/rewriter/ort_fusions/sdpa_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/sdpa_test.py#L71-L78

Added lines #L71 - L78 were not covered by tests


class SDPATestCase:
def __init__(self, script_func):
self.script_func = script_func

def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
qkv_type = FLOAT[B, N, S, H]
mask_type = FLOAT[B, N, S, S]
model_proto = _masked_pre_div_sdpa_script.to_model_proto(
model_proto = self.script_func.to_model_proto(
input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
)
model = ir.serde.deserialize_model(model_proto)
self._onnx_model = model
self._onnx_model = ir.serde.deserialize_model(model_proto)
return self._onnx_model

def get_ort_inputs(self):
Expand All @@ -63,12 +105,20 @@


class TestSDPAFusion(unittest.TestCase):
def test_sdpa_fusion(self):
test = _MaskedPreDivSDPATestCase()
model = test.get_onnx_model()
@parameterized.expand(
[
("pre_div", _masked_pre_div_sdpa_script),
("pre_mul", _masked_pre_mul_sdpa_script),
("post_div", _masked_post_div_sdpa_script),
("post_mul", _masked_post_mul_sdpa_script),
]
)
def test_sdpa_fusion(self, name, script_func):
test_case = SDPATestCase(script_func)
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)

# inputs = test.get_ort_inputs()
# inputs = test_case.get_ort_inputs()
# original_outputs = ort_run("original", model, inputs)

count = fuse_sdpa(model)
Expand Down
Loading