9
9
import unittest
10
10
11
11
import numpy
12
+ from parameterized import parameterized
12
13
13
14
import onnxscript .ir as ir
14
15
import onnxscript .optimizer
22
23
S = 8  # sequence length
H = 128  # head size

# Scaling applied to the attention scores: sqrt(head size). Used directly as a
# divisor by the "div" script variants.
SCALE_FACTOR = math.sqrt(H)
# Reciprocal of SCALE_FACTOR, for the variants that scale with Mul instead of Div.
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
# Square roots of the two factors above. The "pre" variants scale query and key
# separately, so the two copies combine to the full factor inside Q @ K^T.
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)
26
29
27
30
28
31
@script ()
@@ -38,16 +41,55 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
38
41
return attn_output
39
42
40
43
41
- class _MaskedPreDivSDPATestCase :
44
# Masked scaled-dot-product attention, "pre-mul" variant: query and the
# transposed key are each multiplied by SQRT_MUL_SCALE_FACTOR *before* the
# first MatMul, so their product carries the full MUL_SCALE_FACTOR scaling.
# The mask is added to the scores ahead of the softmax.
# NOTE: documented with comments rather than a docstring so that nothing but
# the operator calls below is seen by the @script() translation.
@script()
def _masked_pre_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so the MatMul below computes Q @ K^T.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
    scaled_query = op.Mul(query, multiplier)
    scaled_key = op.Mul(key_transposed, multiplier)
    attn_score = op.MatMul(scaled_query, scaled_key)
    masked_attn_score = op.Add(attn_score, mask)
    # Normalize over the last axis of the score tensor.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
55
+
56
+
57
# Masked scaled-dot-product attention, "post-div" variant: the raw Q @ K^T
# scores are divided by SCALE_FACTOR *after* the first MatMul, then the mask
# is added and the softmax applied.
# NOTE: documented with comments rather than a docstring so that nothing but
# the operator calls below is seen by the @script() translation.
@script()
def _masked_post_div_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so the MatMul below computes Q @ K^T.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    divisor = op.Constant(value_float=SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Div(attn_score, divisor)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    # Normalize over the last axis of the score tensor.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
67
+
68
+
69
# Masked scaled-dot-product attention, "post-mul" variant: the raw Q @ K^T
# scores are multiplied by MUL_SCALE_FACTOR (the reciprocal of SCALE_FACTOR)
# *after* the first MatMul, then the mask is added and the softmax applied.
# NOTE: documented with comments rather than a docstring so that nothing but
# the operator calls below is seen by the @script() translation.
@script()
def _masked_post_mul_sdpa_script(query, key, value, mask):
    # Swap the last two axes of key so the MatMul below computes Q @ K^T.
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    masked_attn_score = op.Add(scaled_attn_score, mask)
    # Normalize over the last axis of the score tensor.
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    attn_output = op.MatMul(attn_weight, value)
    return attn_output
79
+
80
+
81
+ class SDPATestCase :
82
+ def __init__ (self , script_func ):
83
+ self .script_func = script_func
84
+
42
85
def get_onnx_model(self):
    """Return the IR model for ``self.script_func``, building it on first use.

    The model is created once from the script function's model proto, with
    three ``FLOAT[B, N, S, H]`` inputs (query, key, value), one
    ``FLOAT[B, N, S, S]`` mask input and a ``FLOAT[B, N, S, H]`` output. It
    is then cached on the instance as ``_onnx_model``, so later calls return
    the same object.
    """
    if not hasattr(self, "_onnx_model"):
        qkv_type = FLOAT[B, N, S, H]
        mask_type = FLOAT[B, N, S, S]
        model_proto = self.script_func.to_model_proto(
            input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
        )
        self._onnx_model = ir.serde.deserialize_model(model_proto)
    return self._onnx_model
52
94
53
95
def get_ort_inputs (self ):
@@ -63,12 +105,20 @@ def get_ort_inputs(self):
63
105
64
106
65
107
class TestSDPAFusion (unittest .TestCase ):
66
- def test_sdpa_fusion (self ):
67
- test = _MaskedPreDivSDPATestCase ()
68
- model = test .get_onnx_model ()
108
+ @parameterized .expand (
109
+ [
110
+ ("pre_div" , _masked_pre_div_sdpa_script ),
111
+ ("pre_mul" , _masked_pre_mul_sdpa_script ),
112
+ ("post_div" , _masked_post_div_sdpa_script ),
113
+ ("post_mul" , _masked_post_mul_sdpa_script ),
114
+ ]
115
+ )
116
+ def test_sdpa_fusion (self , name , script_func ):
117
+ test_case = SDPATestCase (script_func )
118
+ model = test_case .get_onnx_model ()
69
119
onnxscript .optimizer .optimize (model )
70
120
71
- # inputs = test .get_ort_inputs()
121
+ # inputs = test_case .get_ort_inputs()
72
122
# original_outputs = ort_run("original", model, inputs)
73
123
74
124
count = fuse_sdpa (model )
0 commit comments