Skip to content

Commit e69c5ad

Browse files
authored
Add SDPA fusion unit test case (#2116)
Add SDPA fusion unit test case
1 parent 3d8f64a commit e69c5ad

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""SDPA fusion test cases."""
5+
6+
from __future__ import annotations
7+
8+
import math
9+
import unittest
10+
11+
import numpy
12+
13+
import onnxscript.ir as ir
14+
import onnxscript.optimizer
15+
from onnxscript import script
16+
from onnxscript.onnx_opset import opset18 as op
17+
from onnxscript.onnx_types import FLOAT
18+
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
19+
20+
B = 2 # batch size
21+
N = 4 # number of heads
22+
S = 8 # sequence length
23+
H = 128 # head size
24+
SCALE_FACTOR = math.sqrt(H)
25+
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
26+
27+
28+
@script()
29+
def _masked_pre_div_sdpa_script(query, key, value, mask):
30+
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
31+
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
32+
scaled_query = op.Div(query, divisor)
33+
scaled_key = op.Div(key_transposed, divisor)
34+
attn_score = op.MatMul(scaled_query, scaled_key)
35+
masked_attn_score = op.Add(attn_score, mask)
36+
attn_weight = op.Softmax(masked_attn_score, axis=-1)
37+
attn_output = op.MatMul(attn_weight, value)
38+
return attn_output
39+
40+
41+
class _MaskedPreDivSDPATestCase:
42+
def get_onnx_model(self):
43+
if not hasattr(self, "_onnx_model"):
44+
qkv_type = FLOAT[B, N, S, H]
45+
mask_type = FLOAT[B, N, S, S]
46+
model_proto = _masked_pre_div_sdpa_script.to_model_proto(
47+
input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
48+
)
49+
model = ir.serde.deserialize_model(model_proto)
50+
self._onnx_model = model
51+
return self._onnx_model
52+
53+
def get_ort_inputs(self):
54+
if not hasattr(self, "_ort_inputs"):
55+
inputs = {
56+
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
57+
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
58+
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
59+
"mask": numpy.random.rand(B, N, S, S).astype(numpy.float32),
60+
}
61+
self._ort_inputs = inputs
62+
return self._ort_inputs
63+
64+
65+
class TestSDPAFusion(unittest.TestCase):
66+
def test_sdpa_fusion(self):
67+
test = _MaskedPreDivSDPATestCase()
68+
model = test.get_onnx_model()
69+
onnxscript.optimizer.optimize(model)
70+
71+
# inputs = test.get_ort_inputs()
72+
# original_outputs = ort_run("original", model, inputs)
73+
74+
count = fuse_sdpa(model)
75+
self.assertGreater(count, 0)
76+
77+
# Check that the fusion was successful
78+
op_types = [n.op_type for n in model.graph]
79+
self.assertIn("SDPA", op_types)
80+
81+
# new_outputs = ort_run("optimized", model, inputs)
82+
# assert_allclose(new_outputs, original_outputs)

0 commit comments

Comments
 (0)