Add fusion rule to fuse (query, key, value) to a packed QKV for GQA #2174

Merged (8 commits) on Apr 15, 2025
2 changes: 2 additions & 0 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -14,6 +14,7 @@
)
from onnxscript.rewriter.ort_fusions.attention import fuse_attention
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
@@ -77,6 +78,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
        # If no MHA fusion was applied, we can try the GQA fusion
        # and avoid trying the attention fusion.
fusion_count["gqa"] = fuse_gqa(model)
fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
fusion_count["attention"] = 0
else:
fusion_count["attention"] = fuse_attention(model)
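With this change, `fuse_xformers` also reports the new fusion under the `"packed_qkv_for_gqa"` key of the count dictionary it returns. A minimal sketch of inspecting those counts (the model path is hypothetical, and `fuse_xformers` is imported directly from the private `_core` module shown above):

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions._core import fuse_xformers

model = ir.load("llm_with_gqa.onnx")  # hypothetical model containing transformer attention subgraphs
model, counts = fuse_xformers(model)  # returns the rewritten model and a per-fusion counter dict
print(counts["gqa"], counts["packed_qkv_for_gqa"])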
202 changes: 202 additions & 0 deletions onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
@@ -0,0 +1,202 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Sequence, Union

import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

Dim = Union[int, ir.SymbolicDim]


class PackedQKVForGQAFusion(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("PackedQKVForGQA", remove_nodes=False)

    def pattern(
        self,
        op,
        packed_qkv,
        past_key,
        past_value,
        seqlens_k,
        total_seq_length,
        cos,
        sin,
        q_num_heads,
        kv_num_heads,
        interleaved,
        start1,
        end1,
        start2,
        end2,
        start3,
        end3,
    ):
        """Pattern to detect sliced Q, K, V passed to GQA and replace with packed QKV."""

        # Slice packed QKV into query, key, and value
        query_BSD = op.Slice(packed_qkv, start1, end1, [2], [1], _outputs=["query_sliced"])
        key_BSDkv = op.Slice(packed_qkv, start2, end2, [2], [1], _outputs=["key_sliced"])
        value_BSDkv = op.Slice(packed_qkv, start3, end3, [2], [1], _outputs=["value_sliced"])

        # Pass sliced Q, K, V to GroupQueryAttention
        return op.GroupQueryAttention(
            query_BSD,
            key_BSDkv,
            value_BSDkv,
            past_key,
            past_value,
            seqlens_k,
            total_seq_length,
            cos,
            sin,
            # mask, # TODO: this is not a valid input for GQA
            num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            do_rotary=1,
            rotary_interleaved=interleaved,
            # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
            _domain="com.microsoft",
            _outputs=3,
        )

    def check(
        self,
        op,
        packed_qkv,
        query_sliced,
        key_sliced,
        value_sliced,
        q_num_heads,
        kv_num_heads,
        start1,
        end1,
        start2,
        end2,
        start3,
        end3,
        **_,
    ):
        check_result = pattern.MatchResult()
        self.bindings: dict[str, Dim] = {}

        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
            return not _fusion_utils._check_shape(self.bindings, val, dims)

        # Check that packed_qkv is being split into q, k, v correctly
        # based on the hidden sizes
        if packed_qkv is None or packed_qkv.shape is None or len(packed_qkv.shape) != 3:
            return check_result.fail("packed_qkv is not a 3D tensor.", packed_qkv)
        hidden_size = packed_qkv.shape[2]
        if not isinstance(hidden_size, int):
            return check_result.fail("Hidden size is not an integer.", packed_qkv)
        q_nh = q_num_heads.value
        kv_nh = kv_num_heads.value
        if not isinstance(q_nh, int) or not isinstance(kv_nh, int):
            return check_result.fail(
                "Could not determine the number of heads for query, key and value.",
            )
        head_size = hidden_size // (q_nh + (2 * kv_nh))
        q_hidden_size = head_size * q_nh
        kv_hidden_size = head_size * kv_nh
        if not (
            _ir_utils.is_singleton_value(start1, 0)
            and _ir_utils.is_singleton_value(end1, q_hidden_size)
            and _ir_utils.is_singleton_value(start2, q_hidden_size)
            and _ir_utils.is_singleton_value(end2, (q_hidden_size + kv_hidden_size))
            and _ir_utils.is_singleton_value(start3, (q_hidden_size + kv_hidden_size))
            and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size)
        ):
            return check_result.fail(
                "packed_qkv is not being split into q, k, v correctly based on hidden sizes.",
                packed_qkv,
            )
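        # For example, with hidden_size = 640, q_num_heads = 20 and kv_num_heads = 10
        # (the configuration used in the unit test below), head_size = 640 // (20 + 2 * 10) = 16,
        # so q_hidden_size = 320 and kv_hidden_size = 160; the slices must then cover
        # [0, 320), [320, 480) and [480, 640) along axis 2.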

        # Check packed_qkv shape (B, S, D)
        if no_match(packed_qkv, ["B", "S", "D"]):
            return check_result.fail(
                f"Shape mismatch: {packed_qkv} does not match expected dimensions ['B', 'S', 'D']",
                packed_qkv,
            )

        # Check query (B, S, Dq) and key/value (B, S, Dkv) shapes
        if no_match(query_sliced, ["B", "S", "Dq"]):
            return check_result.fail(
                f"Shape mismatch: {query_sliced} does not match expected dimensions ['B', 'S', 'Dq']",
                query_sliced,
            )
        if no_match(key_sliced, ["B", "S", "Dkv"]):
            return check_result.fail(
                f"Shape mismatch: {key_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
                key_sliced,
            )
        if no_match(value_sliced, ["B", "S", "Dkv"]):
            return check_result.fail(
                f"Shape mismatch: {value_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
                value_sliced,
            )

        # Ensure D = Dq + 2*Dkv
        D = self.bindings.get("D")
        Dq = self.bindings.get("Dq")
        Dkv = self.bindings.get("Dkv")

        if not isinstance(D, int) or not isinstance(Dq, int) or not isinstance(Dkv, int):
            return check_result.fail(
                "Could not determine the hidden sizes of query, key, and value.",
            )

        if Dq + (2 * Dkv) != D:  # type: ignore[operator]
            return check_result.fail(
                f"Hidden size of query, key and value do not add up to hidden size: {D} != {Dq} + (2 * {Dkv})",
            )

        return True

    def rewrite(
        self,
        op,
        packed_qkv,
        past_key,
        past_value,
        seqlens_k,
        total_seq_length,
        cos,
        sin,
        q_num_heads,
        kv_num_heads,
        interleaved,
        **_,
    ):
"""Rewrite the sliced Q, K, V into a packed QKV MatMul input for GQA."""

        # Pass packed QKV directly to GroupQueryAttention
        return op.GroupQueryAttention(
            packed_qkv,
            None,
            None,
            past_key,
            past_value,
            seqlens_k,
            total_seq_length,
            cos,
            sin,
            num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            do_rotary=1,
            rotary_interleaved=interleaved,
            _domain="com.microsoft",
            _outputs=3,
        )


# Define the fusion rule
packed_qkv_for_gqa_rule = PackedQKVForGQAFusion.rule()

# Add the rule to the GQA rewrite rule set
fuse_qkv_gqa_rules = pattern.RewriteRuleSet([packed_qkv_for_gqa_rule])

# Apply the fusion rules
fuse_qkv_gqa = _fusion_utils.apply_fusion_rules(fuse_qkv_gqa_rules)
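
For reference, the `fuse_qkv_gqa` entry point defined above can also be applied on its own, outside of `fuse_xformers`. A minimal sketch, assuming the `onnxscript.ir.load`/`save` helpers and a hypothetical model path; as in the unit test below, shape inference should run first because `check()` relies on known input shapes:

import onnxscript.ir as ir
import onnxscript.ir.passes.common.shape_inference as shape_inference
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa

model = ir.load("model_with_sliced_qkv_gqa.onnx")  # hypothetical model with the Slice + GQA pattern
model = shape_inference.infer_shapes(model)  # populate shapes used by the pattern check
count = fuse_qkv_gqa(model)  # number of GroupQueryAttention nodes rewritten to take packed QKV
ir.save(model, "model_with_packed_qkv_gqa.onnx")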
141 changes: 141 additions & 0 deletions onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py
@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnxruntime as ort

import onnxscript
import onnxscript.ir as ir
import onnxscript.ir.passes.common.shape_inference as shape_inference
import onnxscript.optimizer
from onnxscript import FLOAT, INT32, script
from onnxscript import opset18 as op
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa

msft_op = onnxscript.values.Opset("com.microsoft", 1)

# Test case for fusion of separate query, key and value inputs
# into a single packed QKV input for the GroupQueryAttention operator.


class PackedQKVforGQAFusionTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Config parameters
        self.batchsize = 1
        self.seqlen = 8
        self.kv_seqlen = self.seqlen
        self.past_seqlen = 16
        self.head_size = 16
        self.q_num_heads = 20
        self.kv_num_heads = 10

        # Computed config parameters
        self.q_hidden_size = self.head_size * self.q_num_heads
        self.kv_hidden_size = self.head_size * self.kv_num_heads
        self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size

        # Abbreviations
        B = self.batchsize
        S = self.seqlen
        P = self.past_seqlen
        D = self.hidden_size
        Dh = self.head_size
        Hkv = self.kv_num_heads
        total_seqlen = S + P
        max_seqlen = total_seqlen

        self.input_types = (
            FLOAT["B", "S", D],  # packed_qkv
            FLOAT["B", Hkv, "P", Dh],  # past_key
            FLOAT["B", Hkv, "P", Dh],  # past_value
            INT32["B"],  # seqlens_k
            INT32[1],  # total_sequence_length
            FLOAT["max_seqlen", Dh // 2],  # cos
            FLOAT["max_seqlen", Dh // 2],  # sin
        )
        self.output_types = (
            FLOAT["B", "S", D],  # attention
            FLOAT["B", Hkv, "T", Dh],  # present_key
            FLOAT["B", Hkv, "T", Dh],  # present_value
        )

        self.inputs = {
            "packed_qkv": np.random.rand(B, S, D).astype(np.float32),
            "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
            "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
            "seqlens_k": np.full((B,), total_seqlen - 1, dtype=np.int32),
            "total_sequence_length": np.array([total_seqlen], dtype=np.int32),
            "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
            "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
        }

    def source_model_script(self):
        Hq = self.q_num_heads
        Hkv = self.kv_num_heads

        @script()
        def gqa(packed_qkv, past_key, past_value, seqlens_k, total_sequence_length, cos, sin):
            # Slice packed_qkv into query, key and value
            query_BSD = op.Slice(packed_qkv, [0], [320], [2], [1])
            key_BSDkv = op.Slice(packed_qkv, [320], [480], [2], [1])
            value_BSDkv = op.Slice(packed_qkv, [480], [640], [2], [1])
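            # These boundaries are the packed-QKV offsets implied by the config above:
            # 320 = head_size * q_num_heads, 480 = 320 + head_size * kv_num_heads,
            # and 640 = hidden_size.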

            attn, past_key, past_value = msft_op.GroupQueryAttention(
                query_BSD,
                key_BSDkv,
                value_BSDkv,
                past_key,
                past_value,
                seqlens_k,
                total_sequence_length,
                cos,
                sin,
                num_heads=Hq,
                kv_num_heads=Hkv,
                do_rotary=1,
                rotary_interleaved=0,
            )
            return attn, past_key, past_value

        return gqa

    def test_fuse_packed_qkv_for_gqa(self):
        """
        Test that the fusion from separate query, key and value inputs to a packed QKV
        input for GQA succeeds on the source model and produces an equivalent model.
        """
        inputs = self.inputs

        source_model = self.source_model_script().to_model_proto(
            input_types=self.input_types,
            output_types=self.output_types,
        )
        session = ort.InferenceSession(
            source_model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
        source_model_outputs = session.run(None, inputs)

        source_model_ir = ir.serde.from_proto(source_model)
        inferred_model = shape_inference.infer_shapes(source_model_ir)
        onnxscript.optimizer.optimize(inferred_model)

        count = fuse_qkv_gqa(inferred_model, debug=True)
        self.assertEqual(count, 1)

        fused_model = ir.serde.to_proto(inferred_model)
        session = ort.InferenceSession(
            fused_model.SerializeToString(), providers=("CPUExecutionProvider",)
        )
        fused_model_outputs = session.run(None, inputs)

        self.assertEqual(len(fused_model_outputs), len(source_model_outputs))
        assert_allclose(fused_model_outputs, source_model_outputs)


if __name__ == "__main__":
    unittest.main()