
Commit 9d04864

rewritten mha fusion
1 parent 69ab1ab commit 9d04864

2 files changed: +112, -44 lines changed

onnxscript/rewriter/ort_fusions/fuse_mha_bias.py

Lines changed: 4 additions & 4 deletions
@@ -41,7 +41,7 @@ def pattern(
         past_key,
         past_value,
         num_heads,
-        scale,
+        # scale,
     ):
         if not self._q_no_bias:
             query_BSD = op.Add(query_matmul, q_bias)
@@ -66,7 +66,7 @@ def pattern(
             past_key,
             past_value,
             num_heads=num_heads,
-            scale=scale,
+            # scale=scale,
             _domain="com.microsoft",
         )

@@ -136,7 +136,7 @@ def rewrite(
         past_key,
         past_value,
         num_heads,
-        scale,
+        # scale,
         **_,
     ):
         if self._q_no_bias:
@@ -162,7 +162,7 @@ def rewrite(
             past_key,
             past_value,
             num_heads=num_heads,
-            scale=scale,
+            # scale=scale,
             _domain="com.microsoft",
         )
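The bias-fusion pattern above no longer matches or forwards the scale attribute of com.microsoft MultiHeadAttention. For context, a minimal NumPy sketch, not part of this change and independent of the onnxscript pattern API, of where a scale factor enters dot-product attention, assuming the commonly documented default of 1/sqrt(head_size) when no explicit scale is supplied:

import numpy as np

def sdpa_scores(q, k, scale=None):
    # q, k: (batch, heads, seq, head_dim); assumed default scale is 1/sqrt(head_dim)
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    return (q @ np.swapaxes(k, -1, -2)) * scale

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8, 16))
k = rng.standard_normal((2, 4, 8, 16))

# Passing scale=1/sqrt(head_dim) explicitly matches omitting it entirely.
assert np.allclose(sdpa_scores(q, k), sdpa_scores(q, k, scale=1.0 / np.sqrt(16)))

Under that assumption, dropping an explicit scale equal to 1/sqrt(head_size) leaves the attention scores unchanged.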

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 108 additions & 40 deletions
@@ -36,23 +36,28 @@ def __init__(
         self,
         name,
         *,
+        double_transpose: bool,
         transpose_4d: bool,
         pre_scale_q: bool,
         is_rotary: bool,
         use_mask: bool,
         has_past_present: bool,
-        is_cross_attention: bool,
+        is_cross_attention_from_past: bool,
     ):
         super().__init__(name)
+        self._double_transpose = double_transpose
         self._transpose_4d = transpose_4d
         self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
         self._use_mask = use_mask
         self._has_past_present = has_past_present
-        # Currently, we only support cross-attention when cross
+        # Checks for the cross-attention pattern in which the cross-attention
         # key and value originate from past_key and past_value.
-        # TODO: Support patterns where any key/value can be used for cross-attention.
-        self._is_cross_attention = is_cross_attention
+        self._is_cross_attention_from_past = is_cross_attention_from_past
+        # Store the matched key/value so we can later check whether the
+        # cross-attention really does come from past_key and past_value.
+        self._k_from_past = None
+        self._v_from_past = None

     def pattern(
         self,
@@ -66,6 +71,7 @@ def pattern(
         position_ids,
         cos,
         sin,
+        key_perm,
         q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,12 +86,15 @@ def pattern(
         # Reshape from (B, S, D) to (B, S, H, D/H)
         key_BSHDh = op.Reshape(key_BSD, pattern.ANY_VALUE, _outputs=["key_BSHDh"])

-        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
-        # TODO: Fix condition
-        if not self._is_cross_attention and self._has_past_present:
-            key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 1, 3])
-        else:
-            key_BHSDh = op.Transpose(key_BSHDh, perm=[0, 2, 3, 1])
+        # Possible Transpose patterns for key:
+        # 1. (B, S, H, D/H) -> (B, H, D/H, S)
+        #    A single transpose that avoids the need for a double transpose.
+        # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
+        #    Patterns with a double transpose of key; the double-transpose branch below handles this.
+        # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
+        #    Patterns where key is reshaped to 3D, transposed, and reshaped back to 4D.
+        key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)

         # Reshape from (B, S, D) to (B, S, H, D/H)
         value_BSHDh = op.Reshape(value_BSD, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
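As a sanity check on the comment above, a small NumPy sketch, illustrative only and not from the patch, showing that the three routes all produce the same (B, H, D/H, S) tensor; route 3 is modeled here with a (B*H, S, D/H) intermediate, matching the 3D reshape used later in the pattern:

import numpy as np

B, S, H, Dh = 2, 5, 3, 4
key_BSHDh = np.random.default_rng(0).standard_normal((B, S, H, Dh))

# Route 1: single transpose (B, S, H, Dh) -> (B, H, Dh, S)
route1 = key_BSHDh.transpose(0, 2, 3, 1)

# Route 2: double transpose (B, S, H, Dh) -> (B, H, S, Dh) -> (B, H, Dh, S)
route2 = key_BSHDh.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)

# Route 3: transpose to (B, H, S, Dh), flatten to 3D, swap the last two axes, reshape back
key_BHSDh = key_BSHDh.transpose(0, 2, 1, 3)
route3 = key_BHSDh.reshape(B * H, S, Dh).transpose(0, 2, 1).reshape(B, H, Dh, S)

assert np.array_equal(route1, route2)
assert np.array_equal(route1, route3)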
@@ -119,41 +128,43 @@ def pattern(
         # Concatenate past_key cache and current key, and transpose to enable
         # dot-product attention computation.
         if self._has_past_present:
+            key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
+        else:
             # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention:
+            if self._is_cross_attention_from_past:
                 key_seq = past_key
+                self._k_from_past = key_seq
             else:
-                key_seq = op.Concat(past_key, key_BHSDh_emb, axis=-2)
-        else:
-            key_seq = key_BHSDh_emb
+                key_seq = key_BHSDh_emb

         # Concatenate past_value cache and current value
         if self._has_past_present:
+            value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
+        else:
             # For patterns where cross-attention key/value originates from past_key/past_value
-            if self._is_cross_attention:
+            if self._is_cross_attention_from_past:
                 value_seq = past_value
+                self._v_from_past = value_seq
             else:
-                value_seq = op.Concat(past_value, value_BHSDh, axis=-2)
-        else:
-            value_seq = value_BHSDh
+                value_seq = value_BHSDh

         # Key/value to be used for dot-product attention computation
         key_seq_to_sdpa = key_seq
         value_seq_to_sdpa = value_seq

         # Transpose last two axes of key_seq to compute dot-product via matmul.
-        if self._transpose_4d:
-            if self._has_past_present:
+        if self._double_transpose:
+            if self._transpose_4d:
                 key_seq_to_sdpa = op.Transpose(key_seq_to_sdpa, perm=[0, 1, 3, 2])
-            else:
-                # Transpose after converting to 3D
-                key_seq_BH_Skv_Dh = op.Reshape(
-                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
-                )
-                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
-                key_seq_to_sdpa = op.Reshape(
-                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
-                )
+            else:
+                # Transpose after converting to 3D
+                key_seq_BH_Skv_Dh = op.Reshape(
+                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
+                )
+                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
+                key_seq_to_sdpa = op.Reshape(
+                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
+                )

         # TODO: Remove use_mask once SDPA op is usable
         if self._use_mask:
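For the has_past_present branch above, a NumPy sketch, outside the pattern DSL, of what Concat(past_key, key, axis=-2) does to the key sequence and why the key's last two axes are then transposed before the matmul:

import numpy as np

B, H, Dh = 1, 2, 4
S_past, S_new = 6, 1  # cached tokens and newly computed tokens
rng = np.random.default_rng(0)

past_key = rng.standard_normal((B, H, S_past, Dh))
key_new = rng.standard_normal((B, H, S_new, Dh))
query = rng.standard_normal((B, H, S_new, Dh))

# Concat along axis -2 extends the key sequence: (B, H, S_past + S_new, Dh)
key_seq = np.concatenate([past_key, key_new], axis=-2)
assert key_seq.shape == (B, H, S_past + S_new, Dh)

# Dot-product attention needs key as (B, H, Dh, S_total), hence the transpose
# of the last two axes (perm=[0, 1, 3, 2] in the 4D branch above).
scores = query @ key_seq.transpose(0, 1, 3, 2)
assert scores.shape == (B, H, S_new, S_past + S_new)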
@@ -178,7 +189,7 @@ def pattern(
         attention = op.Reshape(
             attention_transposed, pattern.ANY_VALUE, _outputs=["attention_reshaped"]
         )
-        if self._has_past_present and not self._is_cross_attention:
+        if self._has_past_present:
             return attention, key_seq, value_seq
         else:
             return attention
@@ -192,6 +203,7 @@ def check(
         mask,
         past_key,
         past_value,
+        key_perm,
         query_BSHDh,
         key_BSHDh=None,
         value_BSHDh=None,
@@ -209,7 +221,57 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']",
                 query_BSD,
             )
-        if not self._is_cross_attention:
+        # If cross-attention key/value originates from past_key/past_value,
+        # check that their producer is None. This keeps the matcher from assuming
+        # that the absence of a key/value pattern path implies a cross-attention pattern.
+        if self._is_cross_attention_from_past:
+            if self._k_from_past is not None:
+                if self._k_from_past.producer() is not None:
+                    return check_result.fail(
+                        "Key is not from past_key/past_value. This is not a cross-attention pattern.",
+                    )
+            if self._v_from_past is not None:
+                if self._v_from_past.producer() is not None:
+                    return check_result.fail(
+                        "Value is not from past_key/past_value. This is not a cross-attention pattern.",
+                    )
+            # We only consider patterns where:
+            # 1) double_transpose = True, to avoid the pattern consuming the transpose of key.
+            # 2) is_rotary = False, since if rotary embeddings are used, the key is not from past_key.
+            # TODO: Determine what parameter conditions would make this pattern foolproof.
+            if not self._double_transpose or self._is_rotary:
+                return check_result.fail(
+                    "Key is not from past_key/past_value. This is not a cross-attention pattern.",
+                )
+
+        """
+        # Check for key transpose values
+        k_perm = _ir_utils.get_singleton_value(key_perm)
+        if k_perm is None or not isinstance(k_perm, list):
+            return check_result.fail(
+                "Key permutation is not a list.",
+                key_perm,
+            )
+        if len(k_perm) != 4:
+            return check_result.fail(
+                "Key permutation is not of length 4.",
+                key_perm,
+            )
+        if self._double_transpose:
+            if k_perm != [0, 2, 1, 3]:
+                return check_result.fail(
+                    "Key permutation is not [0, 2, 1, 3].",
+                    key_perm,
+                )
+        else:
+            if k_perm != [0, 2, 3, 1]:
+                return check_result.fail(
+                    "Key permutation is not [0, 2, 3, 1].",
+                    key_perm,
+                )
+        """
+
+        if not self._is_cross_attention_from_past:
             if no_match(key_BSD, ["B", "Skv", "D"]):
                 return check_result.fail(
                     f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']",
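The producer() checks above rely on the idea that a value with no producing node is a graph input (such as past_key/past_value) rather than something computed inside the graph. A stand-alone sketch of that check, using a stand-in Value class rather than the real onnxscript ir.Value API:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeValue:
    # Stand-in for an IR value: graph inputs have no producing node.
    name: str
    producer_node: Optional[str] = None

    def producer(self) -> Optional[str]:
        return self.producer_node

def looks_like_cross_attention_from_past(k, v) -> bool:
    # Mirrors the check above: both key and value must be graph inputs
    # (no producer); otherwise this is not the past_key/past_value pattern.
    return k.producer() is None and v.producer() is None

past_key = FakeValue("past_key")                          # graph input
computed_key = FakeValue("key", producer_node="Concat")   # produced inside the graph

assert looks_like_cross_attention_from_past(past_key, FakeValue("past_value"))
assert not looks_like_cross_attention_from_past(computed_key, FakeValue("past_value"))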
@@ -239,7 +301,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 query_BSHDh,
             )

-        if not self._is_cross_attention:
+        if not self._is_cross_attention_from_past:
             if key_BSHDh and no_match(key_BSHDh, ["B", "S", "H", "Dh"]):
                 return check_result.fail(
                     f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']",
@@ -298,9 +360,9 @@ def rewrite(
         query_BSD_emb = query_BSD
         key_BSD_emb = key_BSD

-        num_outputs = 1 + (2 * self._has_past_present * (not self._is_cross_attention))
-        # Special case for cross-attention
-        if self._has_past_present and self._is_cross_attention:
+        num_outputs = 1 + (2 * self._has_past_present)
+        # Special case for cross-attention that comes from past_key/past_value
+        if self._is_cross_attention_from_past:
             return op.MultiHeadAttention(
                 query_BSD_emb,
                 past_key,
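The new num_outputs expression leans on Python's bool-to-int coercion: three outputs (attention plus present key/value) when a past/present cache is matched, one otherwise. A quick check of the arithmetic:

# has_past_present is a bool; True behaves as 1 and False as 0 in arithmetic.
for has_past_present, expected in [(True, 3), (False, 1)]:
    num_outputs = 1 + (2 * has_past_present)
    assert num_outputs == expected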
@@ -334,31 +396,37 @@ def rewrite(

 parameter_combinations = [
     {
+        "double_transpose": double_transpose,
         "transpose_4d": transpose_4d,
         "pre_scale_q": pre_scale_q,
         "is_rotary": is_rotary,
         "use_mask": use_mask,
         "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
+        "is_cross_attention_from_past": is_cross_attention_from_past,
     }
-    for transpose_4d in [False, True]
+    for double_transpose in [False, True]
+    for transpose_4d in (
+        [False, True] if double_transpose else [False]
+    )  # Only vary transpose_4d when double_transpose is True
     for pre_scale_q in [True, False]
     for is_rotary in [False, True]
     for use_mask in [False, True]
-    for has_past_present in [False, True]
-    for is_cross_attention in [False, True]
+    # TODO: Avoid making this parameter order dependent
+    for has_past_present in [True, False]
+    for is_cross_attention_from_past in [False, True]
 ]

 # Dynamically create the rules
 mha_rules = pattern.RewriteRuleSet(
     [
         MultiHeadAttention.rule(
             f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+            f"{'_Twice' if params['double_transpose'] else ''}"
             f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
             f"{'_Rotary' if params['is_rotary'] else ''}"
             f"{'_Masked' if params['use_mask'] else ''}"
             f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+            f"{'_CrossAttentionFromPast' if params['is_cross_attention_from_past'] else ''}",
             **params,
         )
         for params in parameter_combinations
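With the new double_transpose axis, the comprehension above yields 3 (double_transpose, transpose_4d) pairs times 2^5 combinations of the remaining flags, i.e. 96 generated rules. A plain-Python reproduction of the counting and naming, with the inner flag ordering simplified and pattern.RewriteRuleSet left out:

from itertools import product

combos = [
    {
        "double_transpose": dt,
        "transpose_4d": t4d,
        "pre_scale_q": psq,
        "is_rotary": rot,
        "use_mask": mask,
        "has_past_present": past,
        "is_cross_attention_from_past": xattn,
    }
    for dt in [False, True]
    for t4d in ([False, True] if dt else [False])
    for psq, rot, mask, past, xattn in product([True, False], repeat=5)
]
assert len(combos) == 96  # 3 transpose variants * 2**5 remaining flags

def rule_name(p):
    # Mirrors the f-string concatenation used to name each generated rule.
    return (
        f"MHA_{'4D' if p['transpose_4d'] else '3D'}_Transpose"
        f"{'_Twice' if p['double_transpose'] else ''}"
        f"{'_PreScaleQ' if p['pre_scale_q'] else ''}"
        f"{'_Rotary' if p['is_rotary'] else ''}"
        f"{'_Masked' if p['use_mask'] else ''}"
        f"{'_Past' if p['has_past_present'] else ''}"
        f"{'_CrossAttentionFromPast' if p['is_cross_attention_from_past'] else ''}"
    )

print(rule_name(combos[-1]))  # -> MHA_4D_Transpose_Twice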
