@@ -98,6 +98,19 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
98
98
return attn_output
99
99
100
100
101
+ @script ()
102
+ def _custom_multi_scale_pre_mul_sdpa_script (query , key , value ):
103
+ key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
104
+ multiplier_q = op .Constant (value_float = 0.5 )
105
+ multiplier_k = op .Constant (value_float = 0.5 )
106
+ scaled_query = op .Mul (query , multiplier_q )
107
+ scaled_key = op .Mul (key_transposed , multiplier_k )
108
+ attn_score = op .MatMul (scaled_query , scaled_key )
109
+ attn_weight = op .Softmax (attn_score , axis = - 1 )
110
+ attn_output = op .MatMul (attn_weight , value )
111
+ return attn_output
112
+
113
+
101
114
@script ()
102
115
def _custom_scale_post_div_sdpa_script (query , key , value ):
103
116
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
@@ -265,6 +278,7 @@ class TestSDPAFusion(unittest.TestCase):
265
278
("custom_scale_post_div_masked" , _custom_scale_post_div_sdpa_script ),
266
279
("custom_scale_pre_mul_masked" , _custom_scale_pre_mul_sdpa_script ),
267
280
("custom_scale_pre_div_masked" , _custom_scale_pre_div_sdpa_script ),
281
+ (_custom_multi_scale_pre_mul_sdpa_script , _custom_multi_scale_pre_mul_sdpa_script ),
268
282
]
269
283
)
270
284
def test_sdpa_fusion (self , name , script_func ):
0 commit comments