@@ -36,23 +36,28 @@ def __init__(
         self,
         name,
         *,
+        double_transpose: bool,
         transpose_4d: bool,
         pre_scale_q: bool,
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
-        is_cross_attention: bool,
+        is_cross_attention_from_past: bool,
     ):
         super().__init__(name)
+        self._double_transpose = double_transpose
         self._transpose_4d = transpose_4d
         self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
-        # Currently, we only support cross-attention when cross
+        # Checks for the cross-attention pattern when the
         # query and key originate from past_key and past_value.
-        # TODO: Support patterns where any key/value can be used for cross-attention.
-        self._is_cross_attention = is_cross_attention
+        self._is_cross_attention_from_past = is_cross_attention_from_past
+        # Store the key/value so we can later check that the cross-attention
+        # indeed comes from past_key and past_value.
+        self._k_from_past = None
+        self._v_from_past = None

     def pattern(
         self,
@@ -66,6 +71,7 @@ def pattern(
         position_ids,
         cos,
         sin,
+        key_perm,
         q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,12 +86,15 @@ def pattern(
         # Reshape from (B, S, D) to (B, S, H, D/H)
         key_BSHDh = op.Reshape(key_BSD, pattern.ANY_VALUE, _outputs=["key_BSHDh"])

-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        # TODO: Fix condition
-        if not self._is_cross_attention and self._has_past_present:
-            key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])
-        else:
-            key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 3, 1])
+        # Possible Transpose patterns for key (R denotes a Reshape):
+        # 1. (B, S, H, D/H) -> (B, H, D/H, S)
+        #    A single transpose; avoids the need for a double transpose.
+        # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
+        #    A double transpose of key; handled by the double_transpose flag.
+        # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
+        #    Key is reshaped to 3D, transposed, and reshaped back to 4D.
+        key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)

         # Reshape from (B, S, D) to (B, S, H, D/H)
         value_BSHDh = op.Reshape(value_BSD, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
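As a sanity check on the equivalence of scenarios 1 and 2 above, here is a small numpy sketch (illustrative only; numpy transposes stand in for the ONNX Transpose ops, and the shapes are arbitrary):

import numpy as np

B, S, H, Dh = 2, 5, 4, 8
key = np.random.rand(B, S, H, Dh)

# Scenario 1: single transpose, perm=[0, 2, 3, 1]
single = key.transpose(0, 2, 3, 1)
# Scenario 2: double transpose, perm=[0, 2, 1, 3] then perm=[0, 1, 3, 2]
double = key.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)

assert np.array_equal(single, double)  # both land in (B, H, Dh, S) layout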
@@ -119,41 +128,43 @@ def pattern(
         # Concatenate past_key cache and current key, and transpose to enable
         # dot-product attention computation.
         if self._has_past_present:
+            key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
+        else:
             # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention:
+            if self._is_cross_attention_from_past:
                 key_seq = past_key
+                self._k_from_past = key_seq
             else:
-                key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
-        else:
-            key_seq = key_BHSDh_emb
+                key_seq = key_BHSDh_emb

         # Concatenate past_value cache and current value
         if self._has_past_present:
+            value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
+        else:
             # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention:
+            if self._is_cross_attention_from_past:
                 value_seq = past_value
+                self._v_from_past = value_seq
             else:
-                value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
-        else:
-            value_seq = value_BHSDh
+                value_seq = value_BHSDh

         # Key/value to be used for dot-product attention computation
         key_seq_to_sdpa = key_seq
         value_seq_to_sdpa = value_seq

         # Transpose last two axes of key_seq to compute dot-product via matmul.
-        if self._transpose_4d:
-            if self._has_past_present:
+        if self._double_transpose:
+            if self._transpose_4d:
                 key_seq_to_sdpa = op.Transpose(key_seq_to_sdpa, perm=[0, 1, 3, 2])
-            else:
-                # Transpose after converting to 3D
-                key_seq_BH_Skv_Dh = op.Reshape(
-                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
-                )
-                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
-                key_seq_to_sdpa = op.Reshape(
-                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
-                )
+            else:
+                # Transpose after converting to 3D
+                key_seq_BH_Skv_Dh = op.Reshape(
+                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
+                )
+                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
+                key_seq_to_sdpa = op.Reshape(
+                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
+                )

         # TODO: Remove use_mask once SDPA op is usable
         if self._use_mask:
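The 3D branch above matches the 4D branch, assuming the two Reshapes flatten and then unflatten the batch and head axes as the _outputs names (key_seq_BH_Skv_Dh, key_seq_B_H_Dh_Skv) suggest. A numpy sketch of that equivalence:

import numpy as np

B, H, Skv, Dh = 2, 4, 6, 8
key_seq = np.random.rand(B, H, Skv, Dh)

via_4d = key_seq.transpose(0, 1, 3, 2)  # Transpose with perm=[0, 1, 3, 2]
via_3d = (
    key_seq.reshape(B * H, Skv, Dh)  # key_seq_BH_Skv_Dh
    .transpose(0, 2, 1)              # key_seq_BH_Dh_Skv, perm=[0, 2, 1]
    .reshape(B, H, Dh, Skv)          # key_seq_B_H_Dh_Skv
)
assert np.array_equal(via_4d, via_3d)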
@@ -178,7 +189,7 @@ def pattern(
         attention = op.Reshape(
             attention_transposed, pattern.ANY_VALUE, _outputs=["attention_reshaped"]
         )
-        if self._has_past_present and not self._is_cross_attention:
+        if self._has_past_present:
             return attention, key_seq, value_seq
         else:
             return attention
@@ -192,6 +203,7 @@ def check(
         mask,
         past_key,
         past_value,
+        key_perm,
         query_BSHDh,
         key_BSHDh=None,
         value_BSHDh=None,
@@ -209,7 +221,57 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
                 query_BSD,
             )
-        if not self._is_cross_attention:
+        # If the cross-attention key/value originates from past_key/past_value,
+        # check that their producer is None. This keeps the matcher from assuming
+        # that a missing key/value pattern path implies a cross-attention pattern.
+        if self._is_cross_attention_from_past:
+            if self._k_from_past is not None:
+                if self._k_from_past.producer() is not None:
+                    return check_result.fail(
+                        "Key is not from past_key/past_value. This is not a cross-attention pattern.",
+                    )
+            if self._v_from_past is not None:
+                if self._v_from_past.producer() is not None:
+                    return check_result.fail(
+                        "Value is not from past_key/past_value. This is not a cross-attention pattern.",
+                    )
+            # We only consider patterns where:
+            # 1) double_transpose = True, to avoid the pattern consuming the transpose of key.
+            # 2) is_rotary = False, since with rotary embeddings the key is not from past_key.
+            # TODO: Determine what parameter conditions would make this pattern foolproof.
+            if not self._double_transpose or self._is_rotary:
+                return check_result.fail(
+                    "Key is not from past_key/past_value. This is not a cross-attention pattern.",
+                )
+
+        """
+        # Check the key transpose permutation.
+        k_perm = _ir_utils.get_singleton_value(key_perm)
+        if k_perm is None or not isinstance(k_perm, list):
+            return check_result.fail(
+                "Key permutation is not a list.",
+                key_perm,
+            )
+        if len(k_perm) != 4:
+            return check_result.fail(
+                "Key permutation is not of length 4.",
+                key_perm,
+            )
+        if self._double_transpose:
+            if k_perm != [0, 2, 1, 3]:
+                return check_result.fail(
+                    "Key permutation is not [0, 2, 1, 3].",
+                    key_perm,
+                )
+        else:
+            if k_perm != [0, 2, 3, 1]:
+                return check_result.fail(
+                    "Key permutation is not [0, 2, 3, 1].",
+                    key_perm,
+                )
+        """
+
+        if not self._is_cross_attention_from_past:
             if no_match(key_BSD, ["B", "Skv", "D"]):
                 return check_result.fail(
                     f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
@@ -239,7 +301,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                     query_BSHDh,
                 )

-        if not self._is_cross_attention:
+        if not self._is_cross_attention_from_past:
             if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
                 return check_result.fail(
                     f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
@@ -298,9 +360,9 @@ def rewrite(
             query_BSD_emb = query_BSD
             key_BSD_emb = key_BSD

-        num_outputs = 1 + (2 * self._has_past_present * (not self._is_cross_attention))
-        # Special case for cross-attention
-        if self._has_past_present and self._is_cross_attention:
+        num_outputs = 1 + (2 * self._has_past_present)
+        # Special case for cross-attention that comes from past_key/past_value
+        if self._is_cross_attention_from_past:
             return op.MultiHeadAttention(
                 query_BSD_emb,
                 past_key,
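The num_outputs arithmetic uses Python's bool-as-int semantics: with a KV cache the fused op also returns the present key/value (matching the pattern's "return attention, key_seq, value_seq"), otherwise only the attention output. A quick check:

assert 1 + (2 * False) == 1  # attention only
assert 1 + (2 * True) == 3   # attention plus present key/value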
@@ -334,31 +396,37 @@ def rewrite(

 parameter_combinations = [
     {
+        "double_transpose": double_transpose,
         "transpose_4d": transpose_4d,
         "pre_scale_q": pre_scale_q,
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
+        "is_cross_attention_from_past": is_cross_attention_from_past,
     }
-    for transpose_4d in [False, True]
+    for double_transpose in [False, True]
+    for transpose_4d in (
+        [False, True] if double_transpose else [False]
+    )  # Only vary transpose_4d when double_transpose is True
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
-    for has_past_present in [False, True]
-    for is_cross_attention in [False, True]
+    # TODO: Avoid making this parameter order-dependent
+    for has_past_present in [True, False]
+    for is_cross_attention_from_past in [False, True]
 ]

 # Dynamically create the rules
 mha_rules = pattern.RewriteRuleSet(
     [
         MultiHeadAttention.rule(
             f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+            f"{'_Twice' if params['double_transpose'] else ''}"
             f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+            f"{'_CrossAttentionFromPast' if params['is_cross_attention_from_past'] else ''}",
             **params,
         )
         for params in parameter_combinations
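The guarded generator prunes redundant variants: transpose_4d is only varied when double_transpose is True, since the 4D/3D distinction only matters on the double-transpose path. A standalone sketch (mirroring just those two clauses) shows the effect:

combos = [
    (double_transpose, transpose_4d)
    for double_transpose in [False, True]
    for transpose_4d in ([False, True] if double_transpose else [False])
]
print(combos)  # [(False, False), (True, False), (True, True)] -- 3 variants, not 4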