Fusion extensions to improve GQA fusion #2374

Merged
10 commits merged on Jun 13, 2025
10 changes: 4 additions & 6 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -86,20 +86,18 @@ def fuse(func, **kwargs):
# We apply shape inference after the SDPA fusion as new nodes are added
# in the rewrite rule for certain patterns of SDPA.
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
# Optimize to avoid trying multiple attention-based fusions

fusion_count["gqa"] = fuse(fuse_gqa)
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)

fusion_count["mha1"] = fuse(fuse_mha1)
fusion_count["mha2"] = fuse(fuse_mha2)
if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
# If no MHA fusion was applied, we can try the GQA fusion.
# and avoid trying the attention fusion.
fusion_count["gqa"] = fuse(fuse_gqa)
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
fusion_count["mha_bias"] = 0
fusion_count["attention"] = 0
else:
fusion_count["mha_bias"] = fuse(fuse_mha_bias)
fusion_count["attention"] = fuse(fuse_attention)
fusion_count["gqa"] = 0
fusion_count["gelu"] = fuse(fuse_gelu)
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
# Finally: inline any intermediate fusion functions introduced that were not
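The net effect: the GQA fusion (and its packed-QKV variant) is now attempted unconditionally, before the MHA fusions, rather than only when MHA fusion failed. A condensed sketch of the post-change control flow (this restates the code above, not the literal diff; `fuse` is the local helper that applies one fusion and returns how often it fired):

```python
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)

# GQA is tried first so attention-like subgraphs are claimed by the most
# specific fusion before the MHA/Attention rules get a chance to match them.
fusion_count["gqa"] = fuse(fuse_gqa)
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)

fusion_count["mha1"] = fuse(fuse_mha1)
fusion_count["mha2"] = fuse(fuse_mha2)
if fusion_count["mha1"] == 0 and fusion_count["mha2"] == 0:
    # No MHA was formed, so the MHA-bias and Attention fusions cannot apply.
    fusion_count["mha_bias"] = 0
    fusion_count["attention"] = 0
else:
    fusion_count["mha_bias"] = fuse(fuse_mha_bias)
    fusion_count["attention"] = fuse(fuse_attention)
```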
26 changes: 23 additions & 3 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
@@ -106,7 +106,16 @@
self._inv_freq_cos_sin_cache.clear()

def pattern(
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, extra_dims
self,
op,
x,
inv_freq,
position_ids,
interleaved,
num_heads,
freqs,
dtype,
extra_dims,
):
if not self._const_freqs:
# Compute freqs from inv_freq and position_ids. In the _const_freqs case,
@@ -121,6 +130,13 @@
# if self._reshape:
# position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True)
# position_ids_expanded = op.Reshape(position_ids_expanded, _allow_other_inputs=True)
# inv_freq may optionally be expanded to shape [B, E, 1]
inv_freq = pattern.OrValue(
[
op.Expand(inv_freq, pattern.ANY_VALUE, _outputs=["expanded_inv_freq"]),
inv_freq,
]
)
freqs = op.MatMul(inv_freq, position_ids_expanded) # [B, E, S]
# if self._reshape:
# freqs = op.Reshape(freqs, freqs_3d_shape) # redundant reshape
@@ -140,11 +156,11 @@
sin_4d,
interleaved=interleaved,
num_heads=num_heads,
_domain="ai.onnxruntime.fusion",
_domain="ai.onnxruntime._fusion",
)

def check(
self, context, inv_freq, position_ids, freqs, extra_dims, **_
self, context, inv_freq, position_ids, freqs, extra_dims, expanded_inv_freq=None, **_
) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()
# TODO(rama): handle redundant reshape/expand
@@ -164,6 +180,10 @@
if not _ir_utils.has_rank(inv_freq, 3):
return check_result.fail("inv_freq is not 3D.", inv_freq)
inv_freq_shape = inv_freq.shape
if expanded_inv_freq is not None:
if not _ir_utils.has_rank(expanded_inv_freq, 3):
return check_result.fail("expanded_inv_freq is not 3D.", expanded_inv_freq)

[Codecov / codecov/patch check warning on onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L185: added line was not covered by tests]
# TODO: check expanded_inv_freq shape
if inv_freq.const_value is None: # TODO: should this be inv_freq_shape?
return check_result.fail("inv_freq is not a constant.", inv_freq)
if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
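The key addition is `pattern.OrValue`, which lets one rule match `inv_freq` either directly or behind an optional `Expand` to `[B, E, 1]`. A minimal sketch of the idiom (function names here are illustrative; the `pattern` and `_ir_utils` APIs are the ones used in this file). When the bare alternative matches, the named output `expanded_inv_freq` stays unbound and the `check` callback sees its `None` default:

```python
from onnxscript.rewriter import _ir_utils, pattern

def freqs_pattern(op, inv_freq, position_ids_expanded):
    # Match Expand(inv_freq, ...) or inv_freq itself; only the Expand branch
    # binds the name "expanded_inv_freq".
    inv_freq = pattern.OrValue(
        [
            op.Expand(inv_freq, pattern.ANY_VALUE, _outputs=["expanded_inv_freq"]),
            inv_freq,
        ]
    )
    return op.MatMul(inv_freq, position_ids_expanded)  # [B, E, S]

def freqs_check(context, inv_freq, expanded_inv_freq=None, **_):
    check_result = pattern.MatchResult()
    # expanded_inv_freq is None when the un-expanded branch matched.
    if expanded_inv_freq is not None and not _ir_utils.has_rank(expanded_inv_freq, 3):
        return check_result.fail("expanded_inv_freq is not 3D.", expanded_inv_freq)
    return check_result
```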
116 changes: 85 additions & 31 deletions onnxscript/rewriter/ort_fusions/gqa.py
@@ -78,13 +78,11 @@ def pattern(
value_BSDkv,
past_key,
past_value,
input_ids,
past_seq_length,
total_seq_length,
position_ids_q,
position_ids_k,
cos,
sin,
some_kv_cache,
shape_B111,
mask,
):
# Reshape query from (B, S, D) to (B, S, H, D/H)
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
@@ -101,10 +99,6 @@
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])

position_ids = op.Range(past_seq_length, total_seq_length, 1)
position_ids_q = op.Unsqueeze(position_ids, [0])
position_ids_k = op.Unsqueeze(position_ids, [0])

query_BHSDh_rope = op.RotaryEmbedding(
query_BHSDh,
position_ids_q,
Expand Down Expand Up @@ -141,15 +135,13 @@ def pattern(
value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
)

mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)

key_seq_BHDhT = op.Transpose(key_seq_BHTDh, perm=[0, 1, 3, 2])
attention_BHSDh = op.SDPA(
query_BHSDh_rope,
key_seq_BHDhT,
key_seq_BHTDh,
value_seq_BHTDh,
mask,
_domain="ai.onnxruntime.fusion",
key_format="BHSd",
_domain="ai.onnxruntime._fusion",
)

# Transpose attention back to (B, S, H, D/H)
@@ -209,8 +201,8 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
# Rotary embedding attributes
query_rotary_attributes = query_BHSDh_rope.producer().attributes
key_rotary_attributes = key_BHkvSDh_rope.producer().attributes
query_interleaved = query_rotary_attributes.get("interleaved", 0)
key_interleaved = key_rotary_attributes.get("interleaved", 0)
query_interleaved = query_rotary_attributes.get_int("interleaved", 0)
key_interleaved = key_rotary_attributes.get_int("interleaved", 0)
if query_interleaved != key_interleaved:
return pattern.MatchResult().fail(
"Rotary embedding interleaved attribute mismatch",
@@ -228,42 +220,104 @@ def rewrite(
value_BSDkv,
past_key,
past_value,
total_seq_length,
position_ids_q,
position_ids_k,
cos,
sin,
mask,
**_,
):
total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
one_0D = op.Constant(value_int=1)
one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)

return op.GroupQueryAttention(
return op.GQA(
mask,
position_ids_k,
position_ids_q,
query_BSD,
key_BSDkv,
value_BSDkv,
past_key,
past_value,
seqlens_k,
total_seq_length_int32,
None, # seqlens_k,
None, # total_seq_length_int32,
cos,
sin,
# mask, # TODO: this is not a valid input for GQA
num_heads=self.num_heads,
kv_num_heads=self.kv_num_heads,
do_rotary=1,
rotary_interleaved=self._interleaved,
# skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
_domain="com.microsoft",
_domain="ai.onnxruntime._fusion",
_outputs=3,
)


_rule1 = GroupQueryAttention.rule()
class GQACausalMask(pattern.RewriteRuleClassBase):
def __init__(self):
super().__init__("GQACausalMask", remove_nodes=False)

def pattern(
self,
op,
mask,
input_ids,
some_kv_cache,
shape_B111,
past_seq_length,
total_seq_length,
):
mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
position_ids = op.Range(past_seq_length, total_seq_length, 1)
position_ids_q = op.Unsqueeze(position_ids, [0])
position_ids_k = op.Unsqueeze(position_ids, [0])
return op.GQA(
mask,
position_ids_k,
position_ids_q,
_allow_other_inputs=True,
_domain="ai.onnxruntime._fusion",
_outputs=["attn_output", "key_seq", "value_seq"],
)

def rewrite(
self,
op,
total_seq_length,
attn_output,
**_,
):
# Construct total_seq_length_int32 and seqlens_k
total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
one_0D = op.Constant(value_int=1)
one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)

gqa_node = attn_output.producer()
assert len(gqa_node.inputs) == 12, (
f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
)
query, key, value, past_key, past_value = gqa_node.inputs[3:8]
cos, sin = gqa_node.inputs[10:12]
updated_inputs = [
query,
key,
value,
past_key,
past_value,
seqlens_k,
total_seq_length_int32,
cos,
sin,
]
attributes = gqa_node.attributes
return op.GroupQueryAttention(
*updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
)


gqa_rules = pattern.RewriteRuleSet([_rule1])
_basic_gqa_rule = GroupQueryAttention.rule()
_gqa_causal_mask_rule = GQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])

fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
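GQA fusion is now a two-stage pipeline: the `GroupQueryAttention` rule fuses the core pattern into an intermediate `GQA` op in the private `ai.onnxruntime._fusion` domain (keeping the mask and position ids as inputs), and `GQACausalMask` then matches the causal-mask and `Range`-based position-id computation feeding that op, rewriting everything into the final `com.microsoft` `GroupQueryAttention` and synthesizing the `seqlens_k` and int32 `total_seq_length` inputs it requires. A numeric sketch of that `Cast -> Sub -> Unsqueeze` chain (NumPy purely for illustration):

```python
import numpy as np

total_seq_length = np.int64(12)  # scalar total sequence length from the graph

total_seq_length_int32 = total_seq_length.astype(np.int32)  # Cast(to=INT32)
seqlens_k_0d = total_seq_length_int32 - np.int32(1)         # Sub: 0-D value 11
seqlens_k = np.expand_dims(seqlens_k_0d, axis=0)            # Unsqueeze: shape [1]

assert seqlens_k.dtype == np.int32 and seqlens_k.tolist() == [11]
```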
10 changes: 8 additions & 2 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -307,6 +307,11 @@ def test_fusion(self):
onnx.TensorProto.FLOAT,
["B", self.seqlen, self.num_heads, self.head_size],
)
key_BHSDh_value_info = onnx.helper.make_tensor_value_info(
"key_BHSDh",
onnx.TensorProto.FLOAT,
["B", self.num_heads, self.total_seqlen, self.head_size],
)
key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info(
"key_BSHkvDh",
onnx.TensorProto.FLOAT,
@@ -327,6 +332,7 @@
query_BHSDh_rope_value_info,
key_BHkvSDh_rope_value_info,
query_BSHDh_value_info,
key_BHSDh_value_info,
key_BSHkvDh_value_info,
key_transposed_value_info,
value_BHSDh_value_info,
@@ -338,10 +344,10 @@
onnxscript.optimizer.optimize(inferred_model)

count = fuse_sdpa(inferred_model, debug=True)
self.assertEqual(count, 1)
self.assertGreater(count, 0)

count = fuse_gqa(inferred_model, debug=True)
self.assertEqual(count, 1)
self.assertGreater(count, 0)

fused_model = ir.serde.to_proto(inferred_model)
session = ort.InferenceSession(
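With two rules in the GQA rule set (and multiple SDPA variants), the number of successful fusions is no longer guaranteed to be exactly one, so the test now only requires that at least one fired. Condensed, the test flow is (a sketch of the code above, not the literal test):

```python
onnxscript.optimizer.optimize(inferred_model)      # cleanup/constant folding first

count = fuse_sdpa(inferred_model, debug=True)      # fuse SDPA subgraphs
assert count > 0

count = fuse_gqa(inferred_model, debug=True)       # then rewrite into GQA
assert count > 0

fused_model = ir.serde.to_proto(inferred_model)    # back to ONNX proto for ORT
```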
37 changes: 6 additions & 31 deletions onnxscript/rewriter/ort_fusions/mha.py
@@ -79,25 +79,14 @@ def pattern(

if not self._is_cross_attention:
# Reshape from (B, S, D) to (B, S, H, D/H)
key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])

# Possible Transpose patterns for key:
# This scenario optimizes the need for a double transpose
# 1. (B, S, H, D/H) -> (B, H, D/H, S)
# Patterns with double transpose of key
# Double transpose should handle this optimization
# 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
# Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
# 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])

# Reshape from (B, S, D) to (B, S, H, D/H)
value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
else:
# For cross-attention, key and value are not reshaped
key_BHSDh = key
value_BHSDh = value

if self._is_rotary:
@@ -118,14 +107,14 @@
)
if not self._is_cross_attention:
key_BHSDh_emb = op.RotaryEmbedding(
key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
key, position_ids_k, cos, sin, _domain="com.microsoft"
)
else:
key_BHSDh_emb = key_BHSDh
key_BHSDh_emb = key
else:
# If rotary embedding is not used, we fuse with positional_embeddings
query_BHSDh_emb = query_BHSDh
key_BHSDh_emb = key_BHSDh
key_BHSDh_emb = key

# Concatenate past_key cache and current key, and transpose to enable
# dot-product attention computation.
@@ -144,35 +133,21 @@
key_seq_to_sdpa = key_seq
value_seq_to_sdpa = value_seq

# Transpose last two axes of key_seq to compute dot-product via matmul.
if self._double_transpose:
if self._transpose_4d:
key_seq_to_sdpa = op.Transpose(key_seq_to_sdpa, perm=[0, 1, 3, 2])
else:
# Transpose after converting to 3D
key_seq_BH_Skv_Dh = op.Reshape(
key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
)
key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
key_seq_to_sdpa = op.Reshape(
key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
)

# TODO: Remove use_mask once SDPA op is usable
if self._use_mask:
sdpa = op.SDPA(
query_BHSDh_emb,
key_seq_to_sdpa,
value_seq_to_sdpa,
mask,
_domain="ai.onnxruntime.fusion",
_domain="ai.onnxruntime._fusion",
)
else:
sdpa = op.SDPA(
query_BHSDh_emb,
key_seq_to_sdpa,
value_seq_to_sdpa,
_domain="ai.onnxruntime.fusion",
_domain="ai.onnxruntime._fusion",
)

# Transpose attention back to (B, S, H, D/H)
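The deleted branch encoded key transposition (including the double-transpose and 3D-reshape variants enumerated in the removed comments) inside the MHA pattern itself; after this change the pattern consumes the key as produced, and the transpose bookkeeping is handled elsewhere (the GQA pattern above passes `key_format="BHSd"` to `SDPA` for the same reason). For reference, a NumPy sketch of the layouts named in the shape comments:

```python
import numpy as np

B, S, H, Dh = 2, 4, 8, 16                       # batch, sequence, heads, head dim
key_BSD = np.zeros((B, S, H * Dh), dtype=np.float32)

key_BSHDh = key_BSD.reshape(B, S, H, Dh)        # Reshape: (B, S, D) -> (B, S, H, D/H)
key_BHSDh = key_BSHDh.transpose(0, 2, 1, 3)     # Transpose: -> (B, H, S, D/H)
key_BHDhS = key_BHSDh.transpose(0, 1, 3, 2)     # old extra transpose: -> (B, H, D/H, S)

assert key_BHDhS.shape == (B, H, Dh, S)
```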
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/mha_test.py
@@ -57,7 +57,7 @@ def test_whisper_encoder(self):
original_outputs = ort_run("original", model, inputs)

# Fuse SDPA and MHA
sdpa_count = xformers.fuse_sdpa(model)
sdpa_count = xformers.fuse_sdpa(model, debug=True)
self.assertGreater(sdpa_count, 0)
model = common_passes.ShapeInferencePass()(model).model
mha_count = xformers.fuse_mha1(model)
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/rotary_embedding.py
@@ -56,7 +56,7 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
def rewrite(self, op, x, cos, sin, **_):
num_heads = x.shape[1]
return op.RotaryEmbedding(
x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime._fusion"
)

