Skip to content

Commit b76e1b3

Browse files
authored
Fixes to MHA fusion (#2380)
A couple of cleanup/fixes to MHA fusion: * Add a pattern to handle one transpose pattern (needed for codellama) * Simplify the handling of optional mask Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent d7974ba commit b76e1b3

File tree

1 file changed

+31
-27
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+31
-27
lines changed

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
transpose_4d: bool,
4141
pre_scale_q: bool,
4242
is_rotary: bool,
43-
use_mask: bool,
4443
has_past_present: bool,
4544
is_cross_attention: bool,
4645
):
@@ -49,7 +48,6 @@ def __init__(
4948
self._transpose_4d = transpose_4d
5049
self._pre_scale_q = pre_scale_q
5150
self._is_rotary = is_rotary
52-
self._use_mask = use_mask
5351
self._has_past_present = has_past_present
5452
self._is_cross_attention = is_cross_attention
5553

@@ -59,13 +57,11 @@ def pattern(
5957
query_BSD,
6058
key,
6159
value,
62-
mask,
6360
past_key,
6461
past_value,
6562
position_ids,
6663
cos,
6764
sin,
68-
key_perm,
6965
q_scale,
7066
):
7167
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,6 +76,12 @@ def pattern(
8076
if not self._is_cross_attention:
8177
# Reshape from (B, S, D) to (B, S, H, D/H)
8278
key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
79+
# Key may or may not be transposed at this point, based on usage pattern
80+
key = pattern.OrValue(
81+
[op.Transpose(key, perm=[0, 2, 1, 3]), key],
82+
tag_var="key_transposed",
83+
tag_values=[True, False],
84+
)
8385

8486
# Reshape from (B, S, D) to (B, S, H, D/H)
8587
value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
@@ -133,22 +135,14 @@ def pattern(
133135
key_seq_to_sdpa = key_seq
134136
value_seq_to_sdpa = value_seq
135137

136-
# TODO: Remove use_mask once SDPA op is usable
137-
if self._use_mask:
138-
sdpa = op.SDPA(
139-
query_BHSDh_emb,
140-
key_seq_to_sdpa,
141-
value_seq_to_sdpa,
142-
mask,
143-
_domain="ai.onnxruntime._fusion",
144-
)
145-
else:
146-
sdpa = op.SDPA(
147-
query_BHSDh_emb,
148-
key_seq_to_sdpa,
149-
value_seq_to_sdpa,
150-
_domain="ai.onnxruntime._fusion",
151-
)
138+
sdpa = op.SDPA(
139+
query_BHSDh_emb,
140+
key_seq_to_sdpa,
141+
value_seq_to_sdpa,
142+
_allow_other_inputs=True,
143+
_outputs=["sdpa_output"],
144+
_domain="ai.onnxruntime._fusion",
145+
)
152146

153147
# Transpose attention back to (B, S, H, D/H)
154148
attention_transposed = op.Transpose(sdpa, perm=[0, 2, 1, 3])
@@ -167,17 +161,19 @@ def check(
167161
query_BSD,
168162
key,
169163
value,
170-
mask,
164+
sdpa_output,
171165
past_key,
172166
past_value,
173-
key_perm,
174167
query_BSHDh,
168+
key_transposed=None,
175169
key_BSHDh=None,
176170
value_BSHDh=None,
177171
**_,
178172
) -> pattern.MatchResult: # type: ignore[name-defined]
179173
check_result = pattern.MatchResult()
180174

175+
sdpa_node = sdpa_output.producer()
176+
181177
bindings: dict[str, Dim] = {}
182178

183179
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -223,6 +219,13 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
223219
f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
224220
query_BSD,
225221
)
222+
sdpa_key_format = sdpa_node.attributes.get_string("key_format")
223+
expected_key_format = "BHSd" if key_transposed else "BSHd"
224+
if sdpa_key_format != expected_key_format:
225+
return check_result.fail(
226+
f"Unexpected key format: {sdpa_key_format}. Expected: {expected_key_format}",
227+
sdpa_node,
228+
)
226229
if no_match(value, ["B", "Skv", "D"]):
227230
return check_result.fail(
228231
f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
@@ -245,7 +248,11 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
245248
# ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St)
246249
# That is: broadcast allowed only for the first two dimensions. (Even that is not
247250
# supported by some earlier versions of ORT, which are not supported here.)
248-
if self._use_mask:
251+
mask = None
252+
if len(sdpa_node.inputs) > 3:
253+
mask = sdpa_node.inputs[3]
254+
self.mask = mask
255+
if mask is not None:
249256
if (mask_shape := mask.shape) is None:
250257
return check_result.fail(
251258
"Mask shape cannot be determined.",
@@ -293,7 +300,6 @@ def rewrite(
293300
query_BSD,
294301
key,
295302
value,
296-
mask,
297303
past_key,
298304
past_value,
299305
query_BSHDh,
@@ -335,6 +341,7 @@ def rewrite(
335341
query_BSD_emb = query_BSD
336342
key_BSD_emb = key
337343

344+
mask = self.mask
338345
if self._use_mask_broadcast:
339346
one = op.Constant(value_ints=[1])
340347
S = op.Shape(query_BSD, start=1, end=2)
@@ -365,7 +372,6 @@ def _make_rule_set(has_past_present: bool):
365372
"transpose_4d": transpose_4d,
366373
"pre_scale_q": pre_scale_q,
367374
"is_rotary": is_rotary,
368-
"use_mask": use_mask,
369375
"has_past_present": has_past_present,
370376
"is_cross_attention": is_cross_attention,
371377
}
@@ -375,7 +381,6 @@ def _make_rule_set(has_past_present: bool):
375381
) # Only generate patterns when double_transpose is True
376382
for pre_scale_q in [True, False]
377383
for is_rotary in [False, True]
378-
for use_mask in [False, True]
379384
for is_cross_attention in ([False] if has_past_present else [False, True])
380385
]
381386

@@ -387,7 +392,6 @@ def _make_rule_set(has_past_present: bool):
387392
f"{'_Twice' if params['double_transpose'] else ''}"
388393
f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
389394
f"{'_Rotary' if params['is_rotary'] else ''}"
390-
f"{'_Masked' if params['use_mask'] else ''}"
391395
f"{'_Past' if params['has_past_present'] else ''}"
392396
f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
393397
**params,

0 commit comments

Comments
 (0)