Fixes to MHA fusion #2380

Merged: 1 commit, merged on Jun 14, 2025

58 changes: 31 additions & 27 deletions onnxscript/rewriter/ort_fusions/mha.py
@@ -40,7 +40,6 @@ def __init__(
transpose_4d: bool,
pre_scale_q: bool,
is_rotary: bool,
use_mask: bool,
has_past_present: bool,
is_cross_attention: bool,
):
@@ -49,7 +48,6 @@ def __init__(
self._transpose_4d = transpose_4d
self._pre_scale_q = pre_scale_q
self._is_rotary = is_rotary
self._use_mask = use_mask
self._has_past_present = has_past_present
self._is_cross_attention = is_cross_attention

@@ -59,13 +57,11 @@ def pattern(
query_BSD,
key,
value,
mask,
past_key,
past_value,
position_ids,
cos,
sin,
key_perm,
q_scale,
):
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,6 +76,12 @@ def pattern(
if not self._is_cross_attention:
# Reshape from (B, S, D) to (B, S, H, D/H)
key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
# Key may or may not be transposed at this point, based on usage pattern
key = pattern.OrValue(
[op.Transpose(key, perm=[0, 2, 1, 3]), key],
tag_var="key_transposed",
tag_values=[True, False],
)

# Reshape from (B, S, D) to (B, S, H, D/H)
value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
@@ -133,22 +135,14 @@ def pattern(
key_seq_to_sdpa = key_seq
value_seq_to_sdpa = value_seq

# TODO: Remove use_mask once SDPA op is usable
if self._use_mask:
sdpa = op.SDPA(
query_BHSDh_emb,
key_seq_to_sdpa,
value_seq_to_sdpa,
mask,
_domain="ai.onnxruntime._fusion",
)
else:
sdpa = op.SDPA(
query_BHSDh_emb,
key_seq_to_sdpa,
value_seq_to_sdpa,
_domain="ai.onnxruntime._fusion",
)
sdpa = op.SDPA(
query_BHSDh_emb,
key_seq_to_sdpa,
value_seq_to_sdpa,
_allow_other_inputs=True,
_outputs=["sdpa_output"],
_domain="ai.onnxruntime._fusion",
)

# Transpose attention back to (B, S, H, D/H)
attention_transposed = op.Transpose(sdpa, perm=[0, 2, 1, 3])
@@ -167,17 +161,19 @@ def check(
query_BSD,
key,
value,
mask,
sdpa_output,
past_key,
past_value,
key_perm,
query_BSHDh,
key_transposed=None,
key_BSHDh=None,
value_BSHDh=None,
**_,
) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()

sdpa_node = sdpa_output.producer()

bindings: dict[str, Dim] = {}

def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -223,6 +219,13 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
query_BSD,
)
sdpa_key_format = sdpa_node.attributes.get_string("key_format")
expected_key_format = "BHSd" if key_transposed else "BSHd"
if sdpa_key_format != expected_key_format:
return check_result.fail(
f"Unexpected key format: {sdpa_key_format}. Expected: {expected_key_format}",
sdpa_node,
)
if no_match(value, ["B", "Skv", "D"]):
return check_result.fail(
f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
@@ -245,7 +248,11 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
# ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St)
# That is: broadcast allowed only for the first two dimensions. (Even that is not
# supported by some earlier versions of ORT, which are not supported here.)
if self._use_mask:
mask = None
if len(sdpa_node.inputs) > 3:
mask = sdpa_node.inputs[3]
self.mask = mask
if mask is not None:
if (mask_shape := mask.shape) is None:
return check_result.fail(
"Mask shape cannot be determined.",
@@ -293,7 +300,6 @@ def rewrite(
query_BSD,
key,
value,
mask,
past_key,
past_value,
query_BSHDh,
@@ -335,6 +341,7 @@ def rewrite(
query_BSD_emb = query_BSD
key_BSD_emb = key

mask = self.mask
if self._use_mask_broadcast:
one = op.Constant(value_ints=[1])
S = op.Shape(query_BSD, start=1, end=2)
@@ -365,7 +372,6 @@ def _make_rule_set(has_past_present: bool):
"transpose_4d": transpose_4d,
"pre_scale_q": pre_scale_q,
"is_rotary": is_rotary,
"use_mask": use_mask,
"has_past_present": has_past_present,
"is_cross_attention": is_cross_attention,
}
@@ -375,7 +381,6 @@ def _make_rule_set(has_past_present: bool):
) # Only generate patterns when double_transpose is True
for pre_scale_q in [True, False]
for is_rotary in [False, True]
for use_mask in [False, True]
for is_cross_attention in ([False] if has_past_present else [False, True])
]

@@ -387,7 +392,6 @@ def _make_rule_set(has_past_present: bool):
f"{'_Twice' if params['double_transpose'] else ''}"
f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
f"{'_Rotary' if params['is_rotary'] else ''}"
f"{'_Masked' if params['use_mask'] else ''}"
f"{'_Past' if params['has_past_present'] else ''}"
f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
**params,
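
Note on the key-layout handling in the diff above: the new pattern.OrValue branch records, via its tag_var, whether an extra Transpose(perm=[0, 2, 1, 3]) was matched on the key, and check then requires the fused SDPA node's key_format attribute to agree ("BHSd" when the transpose was matched, "BSHd" otherwise). A minimal standalone sketch of that consistency rule; the helper names and example values are illustrative, not taken from the PR:

def expected_key_format(key_transposed: bool) -> str:
    # When the pattern matches Transpose(perm=[0, 2, 1, 3]) on the key, SDPA
    # receives the key in (B, H, S, D/H) layout; otherwise in (B, S, H, D/H).
    return "BHSd" if key_transposed else "BSHd"


def key_format_matches(sdpa_key_format: str, key_transposed: bool) -> bool:
    # The fusion is rejected in `check` when the SDPA node's key_format
    # attribute disagrees with the layout the pattern actually matched.
    return sdpa_key_format == expected_key_format(key_transposed)


assert key_format_matches("BHSd", key_transposed=True)
assert not key_format_matches("BHSd", key_transposed=False)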
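The SDPA pattern now uses _allow_other_inputs=True, so a single rule matches both the masked and unmasked forms; check recovers the optional mask from the matched node's inputs (sdpa_node.inputs[3] when present), stashes it on self.mask, and rewrite reads it back. That is what allows the use_mask flag and the _Masked rule variants to be dropped. A minimal sketch of the optional-input extraction, assuming only that the matched inputs are available as a sequence (names below are illustrative):

from typing import Optional, Sequence


def optional_mask(sdpa_inputs: Sequence[object]) -> Optional[object]:
    # The first three SDPA inputs are query, key, value; a fourth input,
    # when present, is treated as the attention mask.
    return sdpa_inputs[3] if len(sdpa_inputs) > 3 else None


assert optional_mask(["q", "k", "v"]) is None
assert optional_mask(["q", "k", "v", "attn_mask"]) == "attn_mask"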
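On the mask-broadcast comment in check: ORT's contrib ops (MHA, Attention) accept a mask of shape (1 or B, 1 or H, S, St), i.e. broadcasting is allowed only in the first two dimensions. A small standalone validator expressing the 4-D rule stated in that comment; this is an assumption-labeled sketch, not ORT's actual validation code:

from typing import Sequence


def mask_shape_ok(mask_shape: Sequence[int], B: int, H: int, S: int, St: int) -> bool:
    # Broadcast (a dimension of 1) is only allowed for the batch and head
    # dimensions; the last two dimensions must equal S and St exactly.
    if len(mask_shape) != 4:
        return False
    b, h, s, st = mask_shape
    return b in (1, B) and h in (1, H) and s == S and st == St


assert mask_shape_ok((1, 1, 8, 16), B=2, H=4, S=8, St=16)
assert not mask_shape_ok((2, 4, 1, 16), B=2, H=4, S=8, St=16)  # cannot broadcast over S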