@@ -42,7 +42,7 @@ def __init__(
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
-        is_cross_attention_from_past: bool,
+        is_cross_attention: bool,
     ):
         super().__init__(name)
         self._double_transpose = double_transpose
@@ -51,20 +51,14 @@ def __init__(
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
-        # Checks for cross-attention pattern when cross
-        # query and key originate from past_key and past_value.
-        self._is_cross_attention_from_past = is_cross_attention_from_past
-        # Store the key/value to check if the cross-attention is
-        # indeed from past_key and past_value.
-        self._k_from_past = None
-        self._v_from_past = None
+        self._is_cross_attention = is_cross_attention

     def pattern(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -83,23 +77,28 @@ def pattern(
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
         query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])

-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        key_BSHDh = op.Reshape(key_BSD, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
-
-        # Possible Transpose patterns for key:
-        # This scenario optimizes the need for a double transpose
-        # 1. (B, S, H, D/H) -> (B, H, D/H, S)
-        # Patterns with double transpose of key
-        # Double transpose should handle this optimization
-        # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
-        # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
-        # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
-        key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
-
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        value_BSHDh = op.Reshape(value_BSD, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        if not self._is_cross_attention:
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
+
+            # Possible Transpose patterns for key:
+            # This scenario optimizes the need for a double transpose
+            # 1. (B, S, H, D/H) -> (B, H, D/H, S)
+            # Patterns with double transpose of key
+            # Double transpose should handle this optimization
+            # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
+            # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
+            # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
+            key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
+
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
+            # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+            value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        else:
+            # For cross-attention, key and value are not reshaped
+            key_BHSDh = key
+            value_BHSDh = value

         if self._is_rotary:
             # This is workaround for examples where there is a duplication of Unsqueeze op
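
The branch above is the crux of the change: for self-attention the 3D key/value are reshaped and transposed into per-head layout, while for cross-attention they are matched as-is in 4D. A small NumPy sketch of the layout manipulation the comments describe (shapes and names are illustrative, not the fusion code itself):

```python
import numpy as np

B, S, H, Dh = 2, 4, 8, 64                     # batch, seq, heads, head size (illustrative)
key_BSD = np.random.rand(B, S, H * Dh)

# Self-attention path: (B, S, D) -> (B, S, H, D/H) -> (B, H, S, D/H)
key_BSHDh = key_BSD.reshape(B, S, H, Dh)
key_BHSDh = key_BSHDh.transpose(0, 2, 1, 3)

# The single-transpose variant (case 1 in the comments) goes straight to
# (B, H, D/H, S), which is what a transposed-key matmul consumes.
key_BHDhS = key_BSHDh.transpose(0, 2, 3, 1)
assert np.array_equal(key_BHDhS, key_BHSDh.transpose(0, 1, 3, 2))

# Cross-attention path: key/value already arrive as (B, H, Skv, D/H),
# so the pattern matches them without any Reshape/Transpose.
```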
@@ -117,9 +116,12 @@ def pattern(
             query_BHSDh_emb = op.RotaryEmbedding(
                 query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
             )
-            key_BHSDh_emb = op.RotaryEmbedding(
-                key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BHSDh_emb = op.RotaryEmbedding(
+                    key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BHSDh_emb = key_BHSDh
         else:
             # If rotary embedding is not used, we fuse with positional_embeddings
             query_BHSDh_emb = query_BHSDh
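
Only the query is rotated when `is_cross_attention` is set, since the encoder-side key carries no new positions. For intuition, a rough NumPy sketch of the rotation a rotary-embedding op applies per head, in the non-interleaved convention (the exact convention and cos/sin shapes used by the com.microsoft op are assumptions here):

```python
import numpy as np

def rotate_half(x):
    # Split the head dimension into two halves and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rotary(x_BHSDh, cos_SDh, sin_SDh):
    # cos/sin are indexed by position; broadcast over batch and head dims.
    return x_BHSDh * cos_SDh + rotate_half(x_BHSDh) * sin_SDh
```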
@@ -130,23 +132,13 @@ def pattern(
         if self._has_past_present:
             key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                key_seq = past_key
-                self._k_from_past = key_seq
-            else:
-                key_seq = key_BHSDh_emb
+            key_seq = key_BHSDh_emb

         # Concatenate past_value cache and current value
         if self._has_past_present:
             value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                value_seq = past_value
-                self._v_from_past = value_seq
-            else:
-                value_seq = value_BHSDh
+            value_seq = value_BHSDh

         # Key/value to be used for dot-product attention computation
         key_seq_to_sdpa = key_seq
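
With the `_from_past` special case gone, the cache handling is uniform: when a KV cache is present, the current key/value are appended along the sequence axis; otherwise (including cross-attention) they are used directly. A minimal sketch of the Concat semantics, with illustrative shapes:

```python
import numpy as np

B, H, Dh = 1, 8, 64
past_key = np.zeros((B, H, 10, Dh))     # Spast = 10 cached positions
key_new = np.ones((B, H, 1, Dh))        # one freshly computed position

key_seq = np.concatenate([past_key, key_new], axis=-2)
assert key_seq.shape == (B, H, 11, Dh)  # cache grows along the sequence axis
```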
@@ -198,8 +190,8 @@ def check(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -221,97 +213,57 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
                 query_BSD,
             )
-        # If cross-attention key/value originates from past_key/past_value,
-        # Check if their producer is None, this is done to avoid from the matcher assuming
-        # that if a key/value pattern path does not exist, it is a cross-attention pattern.
-        if self._is_cross_attention_from_past:
-            if self._k_from_past is not None:
-                if self._k_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            if self._v_from_past is not None:
-                if self._v_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Value is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            # We only consider patterns where,
-            # 1) double_transpose = True, to avoid pattern consuming the transpose of key.
-            # 2) is_rotary = False, as if rotary embeddings are used, the key is not from past_key.
-            # TODO: Determine what parameter conditions would make this pattern full-proof.
-            if not self._double_transpose or self._is_rotary:
-                return check_result.fail(
-                    "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                )

-        """
-        # Check for key transpose values
-        k_perm = _ir_utils.get_singleton_value(key_perm)
-        if k_perm is None or not isinstance(k_perm, list):
-            return check_result.fail(
-                f"Key permutation is not a list.",
-                key_perm,
-            )
-        if len(k_perm) != 4:
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
             return check_result.fail(
-                f"Key permutation is not of length 4.",
-                key_perm,
+                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
+                query_BSHDh,
             )
-        if self._double_transpose:
-            if k_perm != [0, 2, 1, 3]:
+        # If cross-attention key/value shapes are 4D
+        if self._is_cross_attention:
+            if no_match(key, ["B", "H", "Skv", "Dh"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 1, 3].",
-                    key_perm,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'H', 'Skv', 'Dh']",
+                    key,
                 )
-        else:
-            if k_perm != [0, 2, 3, 1]:
+            if no_match(value, ["B", "H", "Skv", "Dv"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 3, 1].",
-                    key_perm,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'H', 'Skv', 'Dv']",
+                    value,
                 )
-        """
-
-        if not self._is_cross_attention_from_past:
-            if no_match(key_BSD, ["B", "Skv", "D"]):
+            # Ensure that no past_key/past_value is used in cross-attention
+            if past_key is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    query_BSD,
-                )
-            if no_match(value_BSD, ["B", "Skv", "D"]):
-                return check_result.fail(
-                    f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    value_BSD,
-                )
-
-        if self._has_past_present:
-            if no_match(past_key, ["B", "H", "Spast", "Dh"]):
-                return check_result.fail(
-                    f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                    "past_key should be None in cross-attention.",
                     past_key,
                 )
-            if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            if past_value is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                    "past_value should be None in cross-attention.",
                    past_value,
                )
-
-        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
-            return check_result.fail(
-                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                query_BSHDh,
-            )
-
-        if not self._is_cross_attention_from_past:
-            if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+        else:
+            if no_match(key, ["B", "Skv", "D"]):
                return check_result.fail(
-                    f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
+                    query_BSD,
                )
-            if value_BSHDh and no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            if no_match(value, ["B", "Skv", "D"]):
                return check_result.fail(
-                    f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
+                    value,
                )
+            if self._has_past_present:
+                if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                        past_key,
+                    )
+                if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                        past_value,
+                    )

         # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
         # But this also, unforunately, depends on ORT version.
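
The rewritten check splits cleanly by mode: cross-attention expects 4D (B, H, Skv, Dh/Dv) key/value and no past inputs, while the regular path expects 3D (B, Skv, D) key/value plus optional 4D caches. `no_match` itself is defined earlier in the file; a simplified stand-in to show the intended symbolic-dimension binding (hypothetical, not the repository's helper):

```python
def no_match(val, dims, bindings):
    """Return True ("no match") if val's shape cannot bind to the named dims."""
    shape = getattr(val, "shape", None)
    if shape is None or len(shape) != len(dims):
        return True  # unknown shape or rank mismatch
    for name, dim in zip(dims, shape):
        if bindings.setdefault(name, dim) != dim:
            return True  # the same symbolic name bound to two different sizes
    return False
```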
@@ -326,8 +278,8 @@ def rewrite(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -353,35 +305,21 @@ def rewrite(
             query_BSD_emb = op.RotaryEmbedding(
                 query_BSD, position_ids, cos, sin, _domain="com.microsoft"
             )
-            key_BSD_emb = op.RotaryEmbedding(
-                key_BSD, position_ids, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BSD_emb = op.RotaryEmbedding(
+                    key, position_ids, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BSD_emb = key
         else:
             query_BSD_emb = query_BSD
-            key_BSD_emb = key_BSD
+            key_BSD_emb = key

         num_outputs = 1 + (2 * self._has_past_present)
-        # Special case for cross-attention that comes from past_key/past_value
-        if self._is_cross_attention_from_past:
-            return op.MultiHeadAttention(
-                query_BSD_emb,
-                past_key,
-                past_value,
-                None,  # bias
-                None,  # key padding mask
-                mask,  # attention mask/bias
-                None,
-                None,
-                num_heads=num_heads,
-                scale=scale,
-                _domain="com.microsoft",
-                _outputs=num_outputs,
-            )
-
         return op.MultiHeadAttention(
             query_BSD_emb,
             key_BSD_emb,
-            value_BSD,
+            value,
             None,  # bias
             None,  # key padding mask
             mask,  # attention mask/bias
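
Both modes now funnel into the single MultiHeadAttention call above; for cross-attention the 4D key/value flow straight through as the second and third inputs (my reading of the contrib op's accepted layouts; treat that as an assumption). The output count depends only on whether a cache is produced:

```python
# Illustrative: num_outputs as computed in rewrite().
for has_past_present in (False, True):
    num_outputs = 1 + (2 * has_past_present)  # attention output (+ present key/value)
    print(has_past_present, num_outputs)       # False -> 1, True -> 3
```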
@@ -402,7 +340,7 @@ def rewrite(
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
-        "is_cross_attention_from_past": is_cross_attention_from_past,
+        "is_cross_attention": is_cross_attention,
     }
     for double_transpose in [False, True]
     for transpose_4d in (
@@ -411,9 +349,9 @@ def rewrite(
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
-    # TODO: Avoid this parameter from being order dependent
+    # Enforce has_past_present to be True first, to avoid missing the pattern
     for has_past_present in [True, False]
-    for is_cross_attention_from_past in [False, True]
+    for is_cross_attention in [False, True]
 ]

 # Dynamically create the rules
@@ -426,7 +364,7 @@ def rewrite(
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttentionFromPast' if params['is_cross_attention_from_past'] else ''}",
+            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
             **params,
         )
         for params in parameter_combinations
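
Each parameter combination yields one rule whose name records the enabled features; the `_CrossAttentionFromPast` suffix is simply renamed. A quick sketch of how a suffix composes for one combination (names assumed from the diff):

```python
params = {"is_rotary": True, "use_mask": False,
          "has_past_present": False, "is_cross_attention": True}
suffix = (
    f"{'_Rotary' if params['is_rotary'] else ''}"
    f"{'_Masked' if params['use_mask'] else ''}"
    f"{'_Past' if params['has_past_present'] else ''}"
    f"{'_CrossAttention' if params['is_cross_attention'] else ''}"
)
print(suffix)  # _Rotary_CrossAttention
```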