@@ -40,7 +40,6 @@ def __init__(
40
40
transpose_4d : bool ,
41
41
pre_scale_q : bool ,
42
42
is_rotary : bool ,
43
- use_mask : bool ,
44
43
has_past_present : bool ,
45
44
is_cross_attention : bool ,
46
45
):
@@ -49,7 +48,6 @@ def __init__(
49
48
self ._transpose_4d = transpose_4d
50
49
self ._pre_scale_q = pre_scale_q
51
50
self ._is_rotary = is_rotary
52
- self ._use_mask = use_mask
53
51
self ._has_past_present = has_past_present
54
52
self ._is_cross_attention = is_cross_attention
55
53
@@ -59,13 +57,11 @@ def pattern(
59
57
query_BSD ,
60
58
key ,
61
59
value ,
62
- mask ,
63
60
past_key ,
64
61
past_value ,
65
62
position_ids ,
66
63
cos ,
67
64
sin ,
68
- key_perm ,
69
65
q_scale ,
70
66
):
71
67
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,6 +76,12 @@ def pattern(
80
76
if not self ._is_cross_attention :
81
77
# Reshape from (B, S, D) to (B, S, H, D/H)
82
78
key = op .Reshape (key , pattern .ANY_VALUE , _outputs = ["key_BSHDh" ])
79
+ # Key may or may not be transposed at this point, based on usage pattern
80
+ key = pattern .OrValue (
81
+ [op .Transpose (key , perm = [0 , 2 , 1 , 3 ]), key ],
82
+ tag_var = "key_transposed" ,
83
+ tag_values = [True , False ],
84
+ )
83
85
84
86
# Reshape from (B, S, D) to (B, S, H, D/H)
85
87
value_BSHDh = op .Reshape (value , pattern .ANY_VALUE , _outputs = ["value_BSHDh" ])
@@ -133,22 +135,14 @@ def pattern(
133
135
key_seq_to_sdpa = key_seq
134
136
value_seq_to_sdpa = value_seq
135
137
136
- # TODO: Remove use_mask once SDPA op is usable
137
- if self ._use_mask :
138
- sdpa = op .SDPA (
139
- query_BHSDh_emb ,
140
- key_seq_to_sdpa ,
141
- value_seq_to_sdpa ,
142
- mask ,
143
- _domain = "ai.onnxruntime._fusion" ,
144
- )
145
- else :
146
- sdpa = op .SDPA (
147
- query_BHSDh_emb ,
148
- key_seq_to_sdpa ,
149
- value_seq_to_sdpa ,
150
- _domain = "ai.onnxruntime._fusion" ,
151
- )
138
+ sdpa = op .SDPA (
139
+ query_BHSDh_emb ,
140
+ key_seq_to_sdpa ,
141
+ value_seq_to_sdpa ,
142
+ _allow_other_inputs = True ,
143
+ _outputs = ["sdpa_output" ],
144
+ _domain = "ai.onnxruntime._fusion" ,
145
+ )
152
146
153
147
# Transpose attention back to (B, S, H, D/H)
154
148
attention_transposed = op .Transpose (sdpa , perm = [0 , 2 , 1 , 3 ])
@@ -167,17 +161,19 @@ def check(
167
161
query_BSD ,
168
162
key ,
169
163
value ,
170
- mask ,
164
+ sdpa_output ,
171
165
past_key ,
172
166
past_value ,
173
- key_perm ,
174
167
query_BSHDh ,
168
+ key_transposed = None ,
175
169
key_BSHDh = None ,
176
170
value_BSHDh = None ,
177
171
** _ ,
178
172
) -> pattern .MatchResult : # type: ignore[name-defined]
179
173
check_result = pattern .MatchResult ()
180
174
175
+ sdpa_node = sdpa_output .producer ()
176
+
181
177
bindings : dict [str , Dim ] = {}
182
178
183
179
def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
@@ -223,6 +219,13 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
223
219
f"Shape mismatch: { key } does not match expected dimensions ['B', 'Skv', 'D']" ,
224
220
query_BSD ,
225
221
)
222
+ sdpa_key_format = sdpa_node .attributes .get_string ("key_format" )
223
+ expected_key_format = "BHSd" if key_transposed else "BSHd"
224
+ if sdpa_key_format != expected_key_format :
225
+ return check_result .fail (
226
+ f"Unexpected key format: { sdpa_key_format } . Expected: { expected_key_format } " ,
227
+ sdpa_node ,
228
+ )
226
229
if no_match (value , ["B" , "Skv" , "D" ]):
227
230
return check_result .fail (
228
231
f"Shape mismatch: { value } does not match expected dimensions ['B', 'Skv', 'D']" ,
@@ -245,7 +248,11 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
245
248
# ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St)
246
249
# That is: broadcast allowed only for the first two dimensions. (Even that is not
247
250
# supported by some earlier versions of ORT, which are not supported here.)
248
- if self ._use_mask :
251
+ mask = None
252
+ if len (sdpa_node .inputs ) > 3 :
253
+ mask = sdpa_node .inputs [3 ]
254
+ self .mask = mask
255
+ if mask is not None :
249
256
if (mask_shape := mask .shape ) is None :
250
257
return check_result .fail (
251
258
"Mask shape cannot be determined." ,
@@ -293,7 +300,6 @@ def rewrite(
293
300
query_BSD ,
294
301
key ,
295
302
value ,
296
- mask ,
297
303
past_key ,
298
304
past_value ,
299
305
query_BSHDh ,
@@ -335,6 +341,7 @@ def rewrite(
335
341
query_BSD_emb = query_BSD
336
342
key_BSD_emb = key
337
343
344
+ mask = self .mask
338
345
if self ._use_mask_broadcast :
339
346
one = op .Constant (value_ints = [1 ])
340
347
S = op .Shape (query_BSD , start = 1 , end = 2 )
@@ -365,7 +372,6 @@ def _make_rule_set(has_past_present: bool):
365
372
"transpose_4d" : transpose_4d ,
366
373
"pre_scale_q" : pre_scale_q ,
367
374
"is_rotary" : is_rotary ,
368
- "use_mask" : use_mask ,
369
375
"has_past_present" : has_past_present ,
370
376
"is_cross_attention" : is_cross_attention ,
371
377
}
@@ -375,7 +381,6 @@ def _make_rule_set(has_past_present: bool):
375
381
) # Only generate patterns when double_transpose is True
376
382
for pre_scale_q in [True , False ]
377
383
for is_rotary in [False , True ]
378
- for use_mask in [False , True ]
379
384
for is_cross_attention in ([False ] if has_past_present else [False , True ])
380
385
]
381
386
@@ -387,7 +392,6 @@ def _make_rule_set(has_past_present: bool):
387
392
f"{ '_Twice' if params ['double_transpose' ] else '' } "
388
393
f"{ '_PreScaleQ' if params ['pre_scale_q' ] else '' } "
389
394
f"{ '_Rotary' if params ['is_rotary' ] else '' } "
390
- f"{ '_Masked' if params ['use_mask' ] else '' } "
391
395
f"{ '_Past' if params ['has_past_present' ] else '' } "
392
396
f"{ '_CrossAttention' if params ['is_cross_attention' ] else '' } " ,
393
397
** params ,
0 commit comments