26
26
MUL_SCALE_FACTOR = 1.0 / SCALE_FACTOR
27
27
SQRT_SCALE_FACTOR = math .sqrt (SCALE_FACTOR )
28
28
SQRT_MUL_SCALE_FACTOR = math .sqrt (MUL_SCALE_FACTOR )
29
+ CUSTOM_SCALE_FACTOR = 2.0
29
30
30
31
31
32
@script ()
@@ -77,7 +78,7 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
77
78
@script ()
78
79
def _custom_scale_pre_div_sdpa_script (query , key , value ):
79
80
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
80
- divisor = op .Constant (value_float = 2.0 )
81
+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
81
82
scaled_query = op .Div (query , divisor )
82
83
scaled_key = op .Div (key_transposed , divisor )
83
84
attn_score = op .MatMul (scaled_query , scaled_key )
@@ -89,7 +90,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
89
90
@script ()
90
91
def _custom_scale_pre_mul_sdpa_script (query , key , value ):
91
92
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
92
- multiplier = op .Constant (value_float = 0.5 )
93
+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
93
94
scaled_query = op .Mul (query , multiplier )
94
95
scaled_key = op .Mul (key_transposed , multiplier )
95
96
attn_score = op .MatMul (scaled_query , scaled_key )
@@ -101,8 +102,8 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
101
102
@script ()
102
103
def _custom_multi_scale_pre_mul_sdpa_script (query , key , value ):
103
104
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
104
- multiplier_q = op .Constant (value_float = 0.5 )
105
- multiplier_k = op .Constant (value_float = 0.5 )
105
+ multiplier_q = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
106
+ multiplier_k = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
106
107
scaled_query = op .Mul (query , multiplier_q )
107
108
scaled_key = op .Mul (key_transposed , multiplier_k )
108
109
attn_score = op .MatMul (scaled_query , scaled_key )
@@ -114,7 +115,7 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
114
115
@script ()
115
116
def _custom_scale_post_div_sdpa_script (query , key , value ):
116
117
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
117
- divisor = op .Constant (value_float = 0.1 )
118
+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
118
119
attn_score = op .MatMul (query , key_transposed )
119
120
scaled_attn_score = op .Div (attn_score , divisor )
120
121
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
@@ -125,7 +126,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
125
126
@script ()
126
127
def _custom_scale_post_mul_sdpa_script (query , key , value ):
127
128
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
128
- multiplier = op .Constant (value_float = 0.125 )
129
+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
129
130
attn_score = op .MatMul (query , key_transposed )
130
131
scaled_attn_score = op .Mul (attn_score , multiplier )
131
132
attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
@@ -186,7 +187,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
186
187
@script ()
187
188
def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
188
189
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
189
- divisor = op .Constant (value_float = 2.0 )
190
+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
190
191
scaled_query = op .Div (query , divisor )
191
192
scaled_key = op .Div (key_transposed , divisor )
192
193
attn_score = op .MatMul (scaled_query , scaled_key )
@@ -199,7 +200,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
199
200
@script ()
200
201
def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
201
202
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
202
- multiplier = op .Constant (value_float = 0.5 )
203
+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
203
204
scaled_query = op .Mul (query , multiplier )
204
205
scaled_key = op .Mul (key_transposed , multiplier )
205
206
attn_score = op .MatMul (scaled_query , scaled_key )
@@ -212,7 +213,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
212
213
@script ()
213
214
def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
214
215
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
215
- divisor = op .Constant (value_float = 0.1 )
216
+ divisor = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
216
217
attn_score = op .MatMul (query , key_transposed )
217
218
scaled_attn_score = op .Div (attn_score , divisor )
218
219
masked_attn_score = op .Add (scaled_attn_score , mask )
@@ -224,7 +225,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
224
225
@script ()
225
226
def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
226
227
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
227
- multiplier = op .Constant (value_float = 0.125 )
228
+ multiplier = op .Constant (value_float = CUSTOM_SCALE_FACTOR )
228
229
attn_score = op .MatMul (query , key_transposed )
229
230
scaled_attn_score = op .Mul (attn_score , multiplier )
230
231
masked_attn_score = op .Add (scaled_attn_score , mask )
@@ -278,7 +279,10 @@ class TestSDPAFusion(unittest.TestCase):
278
279
("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
279
280
("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
280
281
("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
281
- ("_custom_multi_scale_pre_mul_sdpa_script" , _custom_multi_scale_pre_mul_sdpa_script ),
282
+ (
283
+ "_custom_multi_scale_pre_mul_sdpa_script" ,
284
+ _custom_multi_scale_pre_mul_sdpa_script ,
285
+ ),
282
286
]
283
287
)
284
288
def test_sdpa_fusion (self , name , script_func ):
@@ -296,6 +300,24 @@ def test_sdpa_fusion(self, name, script_func):
296
300
op_types = [n .op_type for n in model .graph ]
297
301
self .assertIn ("SDPA" , op_types )
298
302
303
+ # Ensure that the scale of the SDPA node is set correctly
304
+ sdpa_node = next (n for n in model .graph if n .op_type == "SDPA" )
305
+ self .assertEqual (sdpa_node .op_type , "SDPA" )
306
+ self .assertIsNotNone (sdpa_node .attributes .get ("scale" ))
307
+
308
+ scale_factor = sdpa_node .attributes ["scale" ].value
309
+ self .assertIsNotNone (scale_factor )
310
+ if "custom" in name :
311
+ if "pre" in name :
312
+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR * CUSTOM_SCALE_FACTOR )
313
+ elif "post" in name :
314
+ self .assertEqual (scale_factor , CUSTOM_SCALE_FACTOR )
315
+ else :
316
+ if "div" in name :
317
+ self .assertEqual (scale_factor , SCALE_FACTOR )
318
+ elif "mul" in name :
319
+ self .assertEqual (scale_factor , MUL_SCALE_FACTOR )
320
+
299
321
# new_outputs = ort_run("optimized", model, inputs)
300
322
# assert_allclose(new_outputs, original_outputs)
301
323
0 commit comments