Skip to content

Commit c4deac2

Browse files
shubhambhokare1 and bmehta001
authored and committed
Extend sdpa tests (microsoft#2118)
Add tests for the pre-mul, post-div, and post-mul variants.
1 parent 8529ce4 commit c4deac2

File tree

1 file changed

+58
-8
lines changed

1 file changed

+58
-8
lines changed

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 58 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import unittest
1010

1111
import numpy
12+
from parameterized import parameterized
1213

1314
import onnxscript.ir as ir
1415
import onnxscript.optimizer
@@ -22,7 +23,9 @@
2223
S = 8 # sequence length
2324
H = 128 # head size
2425
SCALE_FACTOR = math.sqrt(H)
26+
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
2527
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
28+
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
2629

2730

2831
@script()
@@ -38,16 +41,55 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
3841
return attn_output
3942

4043

41-
class _MaskedPreDivSDPATestCase:
44+
# Masked attention graph in the "pre-mul" form: query and the transposed key
# are each multiplied by the same constant before the score MatMul.
# Inputs are declared by the caller via to_model_proto(input_types=...).
@script()
def _masked_pre_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be the right operand of MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    # SQRT_MUL_SCALE_FACTOR is sqrt(MUL_SCALE_FACTOR); applying it to both
    # operands makes their product carry the full MUL_SCALE_FACTOR.
    multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # The mask is added to the scores before normalisation.
    masked_attn_score = op.Add(attn_score, mask)
    # Softmax is taken over the last axis.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
55+
56+
57+
# Masked attention graph in the "post-div" form: the unscaled query/key
# product is divided by SCALE_FACTOR after the score MatMul.
# Inputs are declared by the caller via to_model_proto(input_types=...).
@script()
def _masked_post_div_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be the right operand of MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=SCALE_FACTOR)
    # Scores are computed from the raw inputs; scaling happens afterwards.
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    # The mask is added to the scaled scores before normalisation.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    # Softmax is taken over the last axis.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
67+
68+
69+
# Masked attention graph in the "post-mul" form: the unscaled query/key
# product is multiplied by MUL_SCALE_FACTOR (1.0 / SCALE_FACTOR) after the
# score MatMul.
# Inputs are declared by the caller via to_model_proto(input_types=...).
@script()
def _masked_post_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so it can be the right operand of MatMul.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
    # Scores are computed from the raw inputs; scaling happens afterwards.
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    # The mask is added to the scaled scores before normalisation.
    masked_attn_score = op.Add(scaled_attn_score, mask)
    # Softmax is taken over the last axis.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
79+
80+
81+
class SDPATestCase:
82+
def __init__(self, script_func):
83+
self.script_func = script_func
84+
4285
def get_onnx_model(self):
4386
if not hasattr(self, "_onnx_model"):
4487
qkv_type = FLOAT[B, N, S, H]
4588
mask_type = FLOAT[B, N, S, S]
46-
model_proto = _masked_pre_div_sdpa_script.to_model_proto(
89+
model_proto = self.script_func.to_model_proto(
4790
input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
4891
)
49-
model = ir.serde.deserialize_model(model_proto)
50-
self._onnx_model = model
92+
self._onnx_model = ir.serde.deserialize_model(model_proto)
5193
return self._onnx_model
5294

5395
def get_ort_inputs(self):
@@ -63,12 +105,20 @@ def get_ort_inputs(self):
63105

64106

65107
class TestSDPAFusion(unittest.TestCase):
66-
def test_sdpa_fusion(self):
67-
test = _MaskedPreDivSDPATestCase()
68-
model = test.get_onnx_model()
108+
@parameterized.expand(
109+
[
110+
("pre_div", _masked_pre_div_sdpa_script),
111+
("pre_mul", _masked_pre_mul_sdpa_script),
112+
("post_div", _masked_post_div_sdpa_script),
113+
("post_mul", _masked_post_mul_sdpa_script),
114+
]
115+
)
116+
def test_sdpa_fusion(self, name, script_func):
117+
test_case = SDPATestCase(script_func)
118+
model = test_case.get_onnx_model()
69119
onnxscript.optimizer.optimize(model)
70120

71-
# inputs = test.get_ort_inputs()
121+
# inputs = test_case.get_ort_inputs()
72122
# original_outputs = ort_run("original", model, inputs)
73123

74124
count = fuse_sdpa(model)

0 commit comments

Comments
 (0)