Skip to content

Commit 0deb51b

Browse files
Add fusion rule to fuse (query, key, value) to a packed QKV for GQA (#2174)
1 parent df26586 commit 0deb51b

File tree

3 files changed

+345
-0
lines changed

3 files changed

+345
-0
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from onnxscript.rewriter.ort_fusions.attention import fuse_attention
1616
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
17+
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
1718
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
1819
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
1920
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
@@ -77,6 +78,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
7778
# If no MHA fusion was applied, we can try the GQA fusion.
7879
# and avoid trying the attention fusion.
7980
fusion_count["gqa"] = fuse_gqa(model)
81+
fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
8082
fusion_count["attention"] = 0
8183
else:
8284
fusion_count["attention"] = fuse_attention(model)
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
from typing import Sequence, Union
6+
7+
import onnxscript.ir as ir
8+
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
9+
10+
Dim = Union[int, ir.SymbolicDim]
11+
12+
13+
class PackedQKVForGQAFusion(pattern.RewriteRuleClassBase):
14+
def __init__(self):
15+
super().__init__("PackedQKVForGQA", remove_nodes=False)
16+
17+
def pattern(
18+
self,
19+
op,
20+
packed_qkv,
21+
past_key,
22+
past_value,
23+
seqlens_k,
24+
total_seq_length,
25+
cos,
26+
sin,
27+
q_num_heads,
28+
kv_num_heads,
29+
interleaved,
30+
start1,
31+
end1,
32+
start2,
33+
end2,
34+
start3,
35+
end3,
36+
):
37+
"""Pattern to detect sliced Q, K, V passed to GQA and replace with packed QKV."""
38+
39+
# Slice packed QKV into query, key, and value
40+
query_BSD = op.Slice(packed_qkv, start1, end1, [2], [1], _outputs=["query_sliced"])
41+
key_BSDkv = op.Slice(packed_qkv, start2, end2, [2], [1], _outputs=["key_sliced"])
42+
value_BSDkv = op.Slice(packed_qkv, start3, end3, [2], [1], _outputs=["value_sliced"])
43+
44+
# Pass sliced Q, K, V to GroupQueryAttention
45+
return op.GroupQueryAttention(
46+
query_BSD,
47+
key_BSDkv,
48+
value_BSDkv,
49+
past_key,
50+
past_value,
51+
seqlens_k,
52+
total_seq_length,
53+
cos,
54+
sin,
55+
# mask, # TODO: this is not a valid input for GQA
56+
num_heads=q_num_heads,
57+
kv_num_heads=kv_num_heads,
58+
do_rotary=1,
59+
rotary_interleaved=interleaved,
60+
# skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
61+
_domain="com.microsoft",
62+
_outputs=3,
63+
)
64+
65+
def check(
66+
self,
67+
op,
68+
packed_qkv,
69+
query_sliced,
70+
key_sliced,
71+
value_sliced,
72+
q_num_heads,
73+
kv_num_heads,
74+
start1,
75+
end1,
76+
start2,
77+
end2,
78+
start3,
79+
end3,
80+
**_,
81+
):
82+
check_result = pattern.MatchResult()
83+
self.bindings: dict[str, Dim] = {}
84+
85+
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
86+
return not _fusion_utils._check_shape(self.bindings, val, dims)
87+
88+
# Check that if x is being split into q, k, v correctly
89+
# based on hidden sizes
90+
if packed_qkv is None or packed_qkv.shape is None or len(packed_qkv.shape) != 3:
91+
return check_result.fail("packed_qkv is not a 3D tensor.", packed_qkv)
92+
hidden_size = packed_qkv.shape[2]
93+
if not isinstance(hidden_size, int):
94+
return check_result.fail("Hidden size is not an integer.", packed_qkv)
95+
q_nh = q_num_heads.value
96+
kv_nh = kv_num_heads.value
97+
if not isinstance(q_nh, int) or not isinstance(kv_nh, int):
98+
return check_result.fail(
99+
"Could not determine the number of heads for query, key and value.",
100+
)
101+
head_size = hidden_size // (q_nh + (2 * kv_nh))
102+
q_hidden_size = head_size * q_nh
103+
kv_hidden_size = head_size * kv_nh
104+
if not (
105+
_ir_utils.is_singleton_value(start1, 0)
106+
and _ir_utils.is_singleton_value(end1, q_hidden_size)
107+
and _ir_utils.is_singleton_value(start2, q_hidden_size)
108+
and _ir_utils.is_singleton_value(end2, (q_hidden_size + kv_hidden_size))
109+
and _ir_utils.is_singleton_value(start3, (q_hidden_size + kv_hidden_size))
110+
and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size)
111+
):
112+
return check_result.fail(
113+
"packed_qkv is not being split into q, k, v correctly based on hidden sizes.",
114+
packed_qkv,
115+
)
116+
117+
# Check packed_qkv shape (B, S, D)
118+
if no_match(packed_qkv, ["B", "S", "D"]):
119+
return check_result.fail(
120+
f"Shape mismatch: {packed_qkv} does not match expected dimensions ['B', 'S', 'D']",
121+
packed_qkv,
122+
)
123+
124+
# Check query, key, and value shapes (B, S, Dh)
125+
if no_match(query_sliced, ["B", "S", "Dq"]):
126+
return check_result.fail(
127+
f"Shape mismatch: {query_sliced} does not match expected dimensions ['B', 'S', 'Dq']",
128+
query_sliced,
129+
)
130+
if no_match(key_sliced, ["B", "S", "Dkv"]):
131+
return check_result.fail(
132+
f"Shape mismatch: {key_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
133+
key_sliced,
134+
)
135+
if no_match(value_sliced, ["B", "S", "Dkv"]):
136+
return check_result.fail(
137+
f"Shape mismatch: {value_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
138+
value_sliced,
139+
)
140+
141+
# Ensure Dh = Dg + 2*Dkv
142+
D = self.bindings.get("D")
143+
Dq = self.bindings.get("Dq")
144+
Dkv = self.bindings.get("Dkv")
145+
146+
if not isinstance(D, int) or not isinstance(Dq, int) or not isinstance(Dkv, int):
147+
return check_result.fail(
148+
"Could not determine the hidden sizes of query, key, and value.",
149+
)
150+
151+
if Dq + (2 * Dkv) != D: # type: ignore[operator]
152+
return check_result.fail(
153+
f"Hidden size of query, key and value do not add up to hidden size: {D} != {Dq} + (2 * {Dkv})",
154+
)
155+
156+
return True
157+
158+
def rewrite(
159+
self,
160+
op,
161+
packed_qkv,
162+
past_key,
163+
past_value,
164+
seqlens_k,
165+
total_seq_length,
166+
cos,
167+
sin,
168+
q_num_heads,
169+
kv_num_heads,
170+
interleaved,
171+
**_,
172+
):
173+
"""Rewrite the sliced Q, K, V into a packed QKV MatMul input for GQA."""
174+
175+
# Pass packed QKV directly to GroupQueryAttention
176+
return op.GroupQueryAttention(
177+
packed_qkv,
178+
None,
179+
None,
180+
past_key,
181+
past_value,
182+
seqlens_k,
183+
total_seq_length,
184+
cos,
185+
sin,
186+
num_heads=q_num_heads,
187+
kv_num_heads=kv_num_heads,
188+
do_rotary=1,
189+
rotary_interleaved=interleaved,
190+
_domain="com.microsoft",
191+
_outputs=3,
192+
)
193+
194+
195+
# Define the fusion rule
196+
packed_qkv_for_gqa_rule = PackedQKVForGQAFusion.rule()
197+
198+
# Add the rule to the GQA rewrite rule set
199+
fuse_qkv_gqa_rules = pattern.RewriteRuleSet([packed_qkv_for_gqa_rule])
200+
201+
# Apply the fusion rules
202+
fuse_qkv_gqa = _fusion_utils.apply_fusion_rules(fuse_qkv_gqa_rules)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import numpy as np
8+
import onnxruntime as ort
9+
10+
import onnxscript
11+
import onnxscript.ir as ir
12+
import onnxscript.ir.passes.common.shape_inference as shape_inference
13+
import onnxscript.optimizer
14+
from onnxscript import FLOAT, INT32, script
15+
from onnxscript import opset18 as op
16+
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
17+
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
18+
19+
msft_op = onnxscript.values.Opset("com.microsoft", 1)
20+
21+
# Test case for fusion of separate query, key and value inputs
22+
# into a single packed QKV input for the GroupQueryAttention operator.
23+
24+
25+
class PackedQKVforGQAFusionTest(unittest.TestCase):
26+
def __init__(self, *args, **kwargs):
27+
super().__init__(*args, **kwargs)
28+
# Config parameters
29+
self.batchsize = 1
30+
self.seqlen = 8
31+
self.kv_seqlen = self.seqlen
32+
self.past_seqlen = 16
33+
self.head_size = 16
34+
self.q_num_heads = 20
35+
self.kv_num_heads = 10
36+
37+
# Computed config parameters
38+
self.q_hidden_size = self.head_size * self.q_num_heads
39+
self.kv_hidden_size = self.head_size * self.kv_num_heads
40+
self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size
41+
42+
# Abbreviations
43+
B = self.batchsize
44+
S = self.seqlen
45+
P = self.past_seqlen
46+
D = self.hidden_size
47+
Dh = self.head_size
48+
Hkv = self.kv_num_heads
49+
total_seqlen = S + P
50+
max_seqlen = total_seqlen
51+
52+
self.input_types = (
53+
FLOAT["B", "S", D], # packed_qkv
54+
FLOAT["B", Hkv, "P", Dh], # past_key
55+
FLOAT["B", Hkv, "P", Dh], # past_value
56+
INT32["B"], # seqlens_k
57+
INT32[1], # total_sequence_length
58+
FLOAT["max_seqlen", Dh // 2], # cos
59+
FLOAT["max_seqlen", Dh // 2], # sin
60+
)
61+
self.output_types = (
62+
FLOAT["B", "S", D], # attention
63+
FLOAT["B", Hkv, "T", Dh], # present_key
64+
FLOAT["B", Hkv, "T", Dh], # present_value
65+
)
66+
67+
self.inputs = {
68+
"packed_qkv": np.random.rand(B, S, D).astype(np.float32),
69+
"past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
70+
"past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
71+
"seqlens_k": np.full((B,), total_seqlen - 1, dtype=np.int32),
72+
"total_sequence_length": np.array([total_seqlen], dtype=np.int32),
73+
"cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
74+
"sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
75+
}
76+
77+
def source_model_script(self):
78+
Hq = self.q_num_heads
79+
Hkv = self.kv_num_heads
80+
81+
@script()
82+
def gqa(packed_qkv, past_key, past_value, seqlens_k, total_sequence_length, cos, sin):
83+
# Slice packed_qkv into query, key and value
84+
query_BSD = op.Slice(packed_qkv, [0], [320], [2], [1])
85+
key_BSDkv = op.Slice(packed_qkv, [320], [480], [2], [1])
86+
value_BSDkv = op.Slice(packed_qkv, [480], [640], [2], [1])
87+
88+
attn, past_key, past_value = msft_op.GroupQueryAttention(
89+
query_BSD,
90+
key_BSDkv,
91+
value_BSDkv,
92+
past_key,
93+
past_value,
94+
seqlens_k,
95+
total_sequence_length,
96+
cos,
97+
sin,
98+
num_heads=Hq,
99+
kv_num_heads=Hkv,
100+
do_rotary=1,
101+
rotary_interleaved=0,
102+
)
103+
return attn, past_key, past_value
104+
105+
return gqa
106+
107+
def test_fuse_packed_qkv_for_gqa(self):
108+
"""
109+
Test that fusion from query, key and value to a packed QKV for GQA
110+
is successful on source model and produces an equivalent model.
111+
"""
112+
inputs = self.inputs
113+
114+
source_model = self.source_model_script().to_model_proto(
115+
input_types=self.input_types,
116+
output_types=self.output_types,
117+
)
118+
session = ort.InferenceSession(
119+
source_model.SerializeToString(), providers=("CPUExecutionProvider",)
120+
)
121+
source_model_outputs = session.run(None, inputs)
122+
123+
source_model_ir = ir.serde.from_proto(source_model)
124+
inferred_model = shape_inference.infer_shapes(source_model_ir)
125+
onnxscript.optimizer.optimize(inferred_model)
126+
127+
count = fuse_qkv_gqa(inferred_model, debug=True)
128+
self.assertEqual(count, 1)
129+
130+
fused_model = ir.serde.to_proto(inferred_model)
131+
session = ort.InferenceSession(
132+
fused_model.SerializeToString(), providers=("CPUExecutionProvider",)
133+
)
134+
fused_model_outputs = session.run(None, inputs)
135+
136+
self.assertEqual(len(fused_model_outputs), len(source_model_outputs))
137+
assert_allclose(fused_model_outputs, source_model_outputs)
138+
139+
140+
if __name__ == "__main__":
141+
unittest.main()

0 commit comments

Comments
 (0)