Commit 0bf38e3

rewrite cross attention logic
1 parent 9d04864 commit 0bf38e3

File tree (1 file changed, +84 -146 lines changed)

  • onnxscript/rewriter/ort_fusions/mha.py

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 84 additions & 146 deletions
@@ -42,7 +42,7 @@ def __init__(
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
-        is_cross_attention_from_past: bool,
+        is_cross_attention: bool,
     ):
         super().__init__(name)
         self._double_transpose = double_transpose
@@ -51,20 +51,14 @@ def __init__(
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
-        # Checks for cross-attention pattern when cross
-        # query and key originate from past_key and past_value.
-        self._is_cross_attention_from_past = is_cross_attention_from_past
-        # Store the key/value to check if the cross-attention is
-        # indeed from past_key and past_value.
-        self._k_from_past = None
-        self._v_from_past = None
+        self._is_cross_attention = is_cross_attention

     def pattern(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -83,23 +77,28 @@ def pattern(
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
         query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])

-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        key_BSHDh = op.Reshape(key_BSD, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
-
-        # Possible Transpose patterns for key:
-        # This scenario optimizes the need for a double transpose
-        # 1. (B, S, H, D/H) -> (B, H, D/H, S)
-        # Patterns with double transpose of key
-        # Double transpose should handle this optimization
-        # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
-        # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
-        # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
-        key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
-
-        # Reshape from (B, S, D) to (B, S, H, D/H)
-        value_BSHDh = op.Reshape(value_BSD, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        if not self._is_cross_attention:
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
+
+            # Possible Transpose patterns for key:
+            # This scenario optimizes the need for a double transpose
+            # 1. (B, S, H, D/H) -> (B, H, D/H, S)
+            # Patterns with double transpose of key
+            # Double transpose should handle this optimization
+            # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
+            # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
+            # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
+            key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
+
+            # Reshape from (B, S, D) to (B, S, H, D/H)
+            value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
+            # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+            value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
+        else:
+            # For cross-attention, key and value are not reshaped
+            key_BHSDh = key
+            value_BHSDh = value

         if self._is_rotary:
             # This is workaround for examples where there is a duplication of Unsqueeze op
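
Note: the hunk above distinguishes two key/value layouts. As a rough illustration (a NumPy sketch under assumed shapes; the names B, S, H, Dh, key_3d, and key_cross are hypothetical and not part of the commit), self-attention key/value arrive as 3D (B, S, D) and are reshaped and transposed to (B, H, S, D/H), while cross-attention key/value are expected to already be 4D and are consumed as-is:

    import numpy as np

    B, S, H, Dh = 2, 4, 8, 64
    D = H * Dh

    # Self-attention style input: 3D, then Reshape + Transpose as in the pattern
    key_3d = np.random.rand(B, S, D).astype(np.float32)
    key_BSHDh = key_3d.reshape(B, S, H, Dh)        # (B, S, D) -> (B, S, H, D/H)
    key_BHSDh = key_BSHDh.transpose(0, 2, 1, 3)    # (B, S, H, D/H) -> (B, H, S, D/H)
    assert key_BHSDh.shape == (B, H, S, Dh)

    # Cross-attention style input: already (B, H, S, D/H), passed through unchanged
    key_cross = np.random.rand(B, H, S, Dh).astype(np.float32)
    assert key_cross.shape == key_BHSDh.shape
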
@@ -117,9 +116,12 @@ def pattern(
             query_BHSDh_emb = op.RotaryEmbedding(
                 query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
             )
-            key_BHSDh_emb = op.RotaryEmbedding(
-                key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BHSDh_emb = op.RotaryEmbedding(
+                    key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BHSDh_emb = key_BHSDh
         else:
             # If rotary embedding is not used, we fuse with positional_embeddings
             query_BHSDh_emb = query_BHSDh
@@ -130,23 +132,13 @@ def pattern(
         if self._has_past_present:
             key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                key_seq = past_key
-                self._k_from_past = key_seq
-            else:
-                key_seq = key_BHSDh_emb
+            key_seq = key_BHSDh_emb

         # Concatenate past_value cache and current value
         if self._has_past_present:
             value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
         else:
-            # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention_from_past:
-                value_seq = past_value
-                self._v_from_past = value_seq
-            else:
-                value_seq = value_BHSDh
+            value_seq = value_BHSDh

         # Key/value to be used for dot-product attention computation
         key_seq_to_sdpa = key_seq
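
Note: when has_past_present is set, the pattern appends the current key/value to the cached past along the sequence axis. A minimal NumPy sketch of that concatenation (hypothetical shapes, illustrative only):

    import numpy as np

    B, H, Dh, S_past, S_new = 1, 8, 64, 16, 1
    past_key = np.zeros((B, H, S_past, Dh), dtype=np.float32)
    key_step = np.ones((B, H, S_new, Dh), dtype=np.float32)

    # Mirrors op.Concat(past_key, key_BHSDh_emb, axis=-2): axis -2 is the sequence axis
    key_seq = np.concatenate([past_key, key_step], axis=-2)
    assert key_seq.shape == (B, H, S_past + S_new, Dh)
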
@@ -198,8 +190,8 @@ def check(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -221,97 +213,57 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
                 query_BSD,
             )
-        # If cross-attention key/value originates from past_key/past_value,
-        # Check if their producer is None, this is done to avoid from the matcher assuming
-        # that if a key/value pattern path does not exist, it is a cross-attention pattern.
-        if self._is_cross_attention_from_past:
-            if self._k_from_past is not None:
-                if self._k_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            if self._v_from_past is not None:
-                if self._v_from_past.producer() is not None:
-                    return check_result.fail(
-                        "Value is not from past_key/past_value. This is not a cross-attention pattern.",
-                    )
-            # We only consider patterns where,
-            # 1) double_transpose = True, to avoid pattern consuming the transpose of key.
-            # 2) is_rotary = False, as if rotary embeddings are used, the key is not from past_key.
-            # TODO: Determine what parameter conditions would make this pattern full-proof.
-            if not self._double_transpose or self._is_rotary:
-                return check_result.fail(
-                    "Key is not from past_key/past_value. This is not a cross-attention pattern.",
-                )

-        """
-        # Check for key transpose values
-        k_perm = _ir_utils.get_singleton_value(key_perm)
-        if k_perm is None or not isinstance(k_perm, list):
-            return check_result.fail(
-                f"Key permutation is not a list.",
-                key_perm,
-            )
-        if len(k_perm) != 4:
+        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
             return check_result.fail(
-                f"Key permutation is not of length 4.",
-                key_perm,
+                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
+                query_BSHDh,
             )
-        if self._double_transpose:
-            if k_perm != [0, 2, 1, 3]:
+        # If cross-attention key/value shapes are 4D
+        if self._is_cross_attention:
+            if no_match(key, ["B", "H", "Skv", "Dh"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 1, 3].",
-                    key_perm,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'H', 'Skv', 'Dh']",
+                    key,
                 )
-        else:
-            if k_perm != [0, 2, 3, 1]:
+            if no_match(value, ["B", "H", "Skv", "Dv"]):
                 return check_result.fail(
-                    f"Key permutation is not [0, 2, 3, 1].",
-                    key_perm,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'H', 'Skv', 'Dv']",
+                    value,
                 )
-        """
-
-        if not self._is_cross_attention_from_past:
-            if no_match(key_BSD, ["B", "Skv", "D"]):
+            # Ensure that no past_key/past_value is used in cross-attention
+            if past_key is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    query_BSD,
-                )
-            if no_match(value_BSD, ["B", "Skv", "D"]):
-                return check_result.fail(
-                    f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
-                    value_BSD,
-                )
-
-        if self._has_past_present:
-            if no_match(past_key, ["B", "H", "Spast", "Dh"]):
-                return check_result.fail(
-                    f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                    "past_key should be None in cross-attention.",
                     past_key,
                 )
-            if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+            if past_value is not None:
                 return check_result.fail(
-                    f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                    "past_value should be None in cross-attention.",
                     past_value,
                 )
-
-        if no_match(query_BSHDh, ["B", "S", "H", "Dh"]):
-            return check_result.fail(
-                f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                query_BSHDh,
-            )
-
-        if not self._is_cross_attention_from_past:
-            if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
+        else:
+            if no_match(key, ["B", "Skv", "D"]):
                 return check_result.fail(
-                    f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
+                    query_BSD,
                 )
-            if value_BSHDh and no_match(value_BSHDh, ["B", "S", "H", "Dh"]):
+            if no_match(value, ["B", "Skv", "D"]):
                 return check_result.fail(
-                    f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
-                    query_BSHDh,
+                    f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
+                    value,
                 )
+            if self._has_past_present:
+                if no_match(past_key, ["B", "H", "Spast", "Dh"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']",
+                        past_key,
+                    )
+                if no_match(past_value, ["B", "H", "Spast", "Dv"]):
+                    return check_result.fail(
+                        f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']",
+                        past_value,
+                    )

         # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St)
         # But this also, unforunately, depends on ORT version.
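
Note: the checks above use the no_match helper defined earlier in mha.py (its signature appears in the hunk header) to compare a value's symbolic shape against expected dimension names. A simplified, hypothetical stand-in that only compares rank and per-position labels, not the real implementation:

    from typing import Optional, Sequence

    def no_match_sketch(shape: Optional[Sequence[object]], dims: Sequence[str]) -> bool:
        # The real helper binds symbolic dimensions consistently across all checked
        # values; this sketch only compares rank and per-position labels.
        if shape is None or len(shape) != len(dims):
            return True
        return any(str(actual) != expected for actual, expected in zip(shape, dims))

    assert not no_match_sketch(["B", "Skv", "D"], ["B", "Skv", "D"])
    assert no_match_sketch(["B", "H", "Skv", "Dh"], ["B", "Skv", "D"])
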
@@ -326,8 +278,8 @@ def rewrite(
         self,
         op,
         query_BSD,
-        key_BSD,
-        value_BSD,
+        key,
+        value,
         mask,
         past_key,
         past_value,
@@ -353,35 +305,21 @@ def rewrite(
             query_BSD_emb = op.RotaryEmbedding(
                 query_BSD, position_ids, cos, sin, _domain="com.microsoft"
             )
-            key_BSD_emb = op.RotaryEmbedding(
-                key_BSD, position_ids, cos, sin, _domain="com.microsoft"
-            )
+            if not self._is_cross_attention:
+                key_BSD_emb = op.RotaryEmbedding(
+                    key, position_ids, cos, sin, _domain="com.microsoft"
+                )
+            else:
+                key_BSD_emb = key
         else:
             query_BSD_emb = query_BSD
-            key_BSD_emb = key_BSD
+            key_BSD_emb = key

         num_outputs = 1 + (2 * self._has_past_present)
-        # Special case for cross-attention that comes from past_key/past_value
-        if self._is_cross_attention_from_past:
-            return op.MultiHeadAttention(
-                query_BSD_emb,
-                past_key,
-                past_value,
-                None,  # bias
-                None,  # key padding mask
-                mask,  # attention mask/bias
-                None,
-                None,
-                num_heads=num_heads,
-                scale=scale,
-                _domain="com.microsoft",
-                _outputs=num_outputs,
-            )
-
         return op.MultiHeadAttention(
             query_BSD_emb,
             key_BSD_emb,
-            value_BSD,
+            value,
             None,  # bias
             None,  # key padding mask
             mask,  # attention mask/bias
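
Note: the rewrite replaces the matched subgraph with a single com.microsoft MultiHeadAttention node. As a rough reference for the computation that node performs (an illustrative NumPy sketch of scaled dot-product attention on (B, H, S, Dh) tensors; it ignores head packing, KV caches, and other ORT-specific details):

    import numpy as np

    def sdpa_reference(q, k, v, mask=None, scale=None):
        # q, k, v: (B, H, S, Dh); mask, if given, broadcastable to (B, H, Sq, Skv)
        if scale is None:
            scale = 1.0 / np.sqrt(q.shape[-1])
        scores = (q @ k.transpose(0, 1, 3, 2)) * scale   # (B, H, Sq, Skv)
        if mask is not None:
            scores = scores + mask
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        return weights @ v                               # (B, H, Sq, Dh)
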
@@ -402,7 +340,7 @@ def rewrite(
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
-        "is_cross_attention_from_past": is_cross_attention_from_past,
+        "is_cross_attention": is_cross_attention,
     }
     for double_transpose in [False, True]
     for transpose_4d in (
@@ -411,9 +349,9 @@ def rewrite(
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
-    # TODO: Avoid this parameter from being order dependent
+    # Enforce has_past_present to be True first, to avoid missing the pattern
     for has_past_present in [True, False]
-    for is_cross_attention_from_past in [False, True]
+    for is_cross_attention in [False, True]
 ]

 # Dynamically create the rules
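
Note: the nested for clauses above produce one parameter dict per combination of fusion flags; has_past_present iterates [True, False] so the past/present variant is tried before the cross-attention one. An equivalent itertools sketch (illustrative only, not the actual rule-generation code; flag_space is a hypothetical name):

    import itertools

    flag_space = {
        "is_rotary": [False, True],
        "use_mask": [False, True],
        "has_past_present": [True, False],
        "is_cross_attention": [False, True],
    }
    combos = [dict(zip(flag_space, values)) for values in itertools.product(*flag_space.values())]
    assert len(combos) == 2**4  # each flag doubles the number of generated rule variants
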
@@ -426,7 +364,7 @@ def rewrite(
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttentionFromPast' if params['is_cross_attention_from_past'] else ''}",
+            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
             **params,
         )
         for params in parameter_combinations
