26
26
# Multiplicative counterpart of SCALE_FACTOR: multiplying by this value is
# equivalent to dividing by SCALE_FACTOR.
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR

# Square roots of the two factors above.
SQRT_SCALE_FACTOR = math.sqrt(SCALE_FACTOR)
SQRT_MUL_SCALE_FACTOR = math.sqrt(MUL_SCALE_FACTOR)

# Constant fed to the Div/Mul nodes of the custom-scale scripts in this file.
CUSTOM_SCALE_FACTOR = 2.0
29
30
30
31
31
32
@script ()
@@ -74,6 +75,65 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
74
75
return attn_output
75
76
76
77
78
# Unmasked attention script with a custom scale applied *before* the Q x K^T
# product: both the query and the transposed key are divided by
# CUSTOM_SCALE_FACTOR, so the product is divided by that factor twice.
#
# NOTE(review): a later definition in this file reuses this exact name with an
# extra `mask` parameter and rebinds it at module level, so this three-input
# version is no longer what the name refers to once the module has loaded.
# Consider giving the two scripts distinct names.
@script()
def _custom_scale_pre_div_sdpa_script(query, key, value):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    scale = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    q_scaled = op.Div(query, scale)
    k_scaled = op.Div(k_t, scale)
    scores = op.MatMul(q_scaled, k_scaled)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
88
+
89
+
90
# Unmasked attention script with a custom scale applied *before* the Q x K^T
# product: both the query and the transposed key are multiplied by
# CUSTOM_SCALE_FACTOR, so the product carries that factor twice.
#
# NOTE(review): a later definition in this file reuses this exact name with an
# extra `mask` parameter and rebinds it at module level, so this three-input
# version is no longer what the name refers to once the module has loaded.
# Consider giving the two scripts distinct names.
@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    q_scaled = op.Mul(query, factor)
    k_scaled = op.Mul(k_t, factor)
    scores = op.MatMul(q_scaled, k_scaled)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
100
+
101
+
102
# Unmasked attention script that pre-multiplies the query and the transposed
# key by CUSTOM_SCALE_FACTOR, using a *separate* Constant node for each operand
# (unlike the single shared constant in _custom_scale_pre_mul_sdpa_script).
@script()
def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    # Two distinct constants holding the same value, one per operand.
    q_factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    k_factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    q_scaled = op.Mul(query, q_factor)
    k_scaled = op.Mul(k_t, k_factor)
    scores = op.MatMul(q_scaled, k_scaled)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
113
+
114
+
115
# Unmasked attention script with a custom scale applied *after* the Q x K^T
# product: the raw scores are divided once by CUSTOM_SCALE_FACTOR.
#
# NOTE(review): a later definition in this file reuses this exact name with an
# extra `mask` parameter and rebinds it at module level, so this three-input
# version is no longer what the name refers to once the module has loaded.
# Consider giving the two scripts distinct names.
@script()
def _custom_scale_post_div_sdpa_script(query, key, value):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    scale = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    raw_scores = op.MatMul(query, k_t)
    scores = op.Div(raw_scores, scale)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
124
+
125
+
126
# Unmasked attention script with a custom scale applied *after* the Q x K^T
# product: the raw scores are multiplied once by CUSTOM_SCALE_FACTOR.
#
# NOTE(review): a later definition in this file reuses this exact name with an
# extra `mask` parameter and rebinds it at module level, so this three-input
# version is no longer what the name refers to once the module has loaded.
# Consider giving the two scripts distinct names.
@script()
def _custom_scale_post_mul_sdpa_script(query, key, value):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    raw_scores = op.MatMul(query, k_t)
    scores = op.Mul(raw_scores, factor)
    probs = op.Softmax(scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
135
+
136
+
77
137
@script ()
78
138
def _masked_pre_div_sdpa_script (query , key , value , mask ):
79
139
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -124,6 +184,56 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
124
184
return attn_output
125
185
126
186
187
# Masked attention script with a custom scale applied *before* the Q x K^T
# product: query and transposed key are each divided by CUSTOM_SCALE_FACTOR,
# then `mask` is added to the scores ahead of the softmax.
#
# NOTE(review): this reuses the name of the three-input (unmasked) script
# defined earlier in this file and replaces that binding at module level, so
# every later reference to the name resolves to this masked version. Consider
# a distinct name (e.g. a `_masked_` prefix, like the other masked scripts).
@script()
def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    scale = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    q_scaled = op.Div(query, scale)
    k_scaled = op.Div(k_t, scale)
    scores = op.MatMul(q_scaled, k_scaled)
    # The mask is additive and is applied after scaling.
    masked_scores = op.Add(scores, mask)
    probs = op.Softmax(masked_scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
198
+
199
+
200
# Masked attention script with a custom scale applied *before* the Q x K^T
# product: query and transposed key are each multiplied by CUSTOM_SCALE_FACTOR,
# then `mask` is added to the scores ahead of the softmax.
#
# NOTE(review): this reuses the name of the three-input (unmasked) script
# defined earlier in this file and replaces that binding at module level, so
# every later reference to the name resolves to this masked version. Consider
# a distinct name (e.g. a `_masked_` prefix, like the other masked scripts).
@script()
def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    q_scaled = op.Mul(query, factor)
    k_scaled = op.Mul(k_t, factor)
    scores = op.MatMul(q_scaled, k_scaled)
    # The mask is additive and is applied after scaling.
    masked_scores = op.Add(scores, mask)
    probs = op.Softmax(masked_scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
211
+
212
+
213
# Masked attention script with a custom scale applied *after* the Q x K^T
# product: the raw scores are divided once by CUSTOM_SCALE_FACTOR, then `mask`
# is added ahead of the softmax.
#
# NOTE(review): this reuses the name of the three-input (unmasked) script
# defined earlier in this file and replaces that binding at module level, so
# every later reference to the name resolves to this masked version. Consider
# a distinct name (e.g. a `_masked_` prefix, like the other masked scripts).
@script()
def _custom_scale_post_div_sdpa_script(query, key, value, mask):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    scale = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    raw_scores = op.MatMul(query, k_t)
    scores = op.Div(raw_scores, scale)
    # The mask is additive and is applied after scaling.
    masked_scores = op.Add(scores, mask)
    probs = op.Softmax(masked_scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
223
+
224
+
225
# Masked attention script with a custom scale applied *after* the Q x K^T
# product: the raw scores are multiplied once by CUSTOM_SCALE_FACTOR, then
# `mask` is added ahead of the softmax.
#
# NOTE(review): this reuses the name of the three-input (unmasked) script
# defined earlier in this file and replaces that binding at module level, so
# every later reference to the name resolves to this masked version. Consider
# a distinct name (e.g. a `_masked_` prefix, like the other masked scripts).
@script()
def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
    # Swap the last two of the four axes so the key can be right-multiplied.
    k_t = op.Transpose(key, perm=[0, 1, 3, 2])
    factor = op.Constant(value_float=CUSTOM_SCALE_FACTOR)
    raw_scores = op.MatMul(query, k_t)
    scores = op.Mul(raw_scores, factor)
    # The mask is additive and is applied after scaling.
    masked_scores = op.Add(scores, mask)
    probs = op.Softmax(masked_scores, axis=-1)
    attn_output = op.MatMul(probs, value)
    return attn_output
235
+
236
+
127
237
class SDPATestCase :
128
238
def __init__ (self , script_func ):
129
239
self .script_func = script_func
@@ -161,6 +271,18 @@ class TestSDPAFusion(unittest.TestCase):
161
271
("pre_mul" , _masked_pre_mul_sdpa_script ),
162
272
("post_div" , _masked_post_div_sdpa_script ),
163
273
("post_mul" , _masked_post_mul_sdpa_script ),
274
+ ("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
275
+ ("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
276
+ ("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
277
+ ("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
278
+ ("custom_scale_post_mul_masked" , _custom_scale_post_mul_sdpa_script ),
279
+ ("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
280
+ ("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
281
+ ("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
282
+ (
283
+ "_custom_multi_scale_pre_mul_sdpa_script" ,
284
+ _custom_multi_scale_pre_mul_sdpa_script ,
285
+ ),
164
286
]
165
287
)
166
288
def test_sdpa_fusion (self , name , script_func ):
@@ -178,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
178
300
op_types = [n .op_type for n in model .graph ]
179
301
self .assertIn ("SDPA" , op_types )
180
302
303
+ # Ensure that the scale of the SDPA node is set correctly
304
+ sdpa_node = next (n for n in model .graph if n .op_type == "SDPA" )
305
+ self .assertEqual (sdpa_node .op_type , "SDPA" )
306
+
307
+ if "custom" in name :
308
+ self .assertIsNotNone (sdpa_node .attributes .get ("scale" ))
309
+ scale_factor = sdpa_node .attributes ["scale" ].value
310
+ self .assertIsNotNone (scale_factor )
311
+ if "pre" in name :
312
+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR )
313
+ elif "post" in name :
314
+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR )
315
+ else :
316
+ # These cases use the default scaling factor, so no scale factor is passed to SDPA.
317
+ # The pattern-rewriting check functions should be sufficient to verify that the
318
+ # matched scale factor equals the default scaling factor.
319
+ self .assertIsNone (sdpa_node .attributes .get ("scale" ))
320
+
181
321
# new_outputs = ort_run("optimized", model, inputs)
182
322
# assert_allclose(new_outputs, original_outputs)
183
323
0 commit comments