
Commit 3ae18a7

add unit test
1 parent 5b3ebfe commit 3ae18a7

File tree

2 files changed: +158 -3 lines changed


onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py

Lines changed: 9 additions & 3 deletions
@@ -92,9 +92,15 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         hidden_size = packed_qkv.shape[2]
         if not isinstance(hidden_size, int):
             return check_result.fail("Hidden size is not an integer.", packed_qkv)
-        head_size = hidden_size // (q_num_heads + (2 * kv_num_heads))
-        q_hidden_size = head_size * q_num_heads
-        kv_hidden_size = head_size * kv_num_heads
+        q_nh = q_num_heads.value
+        kv_nh = kv_num_heads.value
+        if not isinstance(q_nh, int) or not isinstance(kv_nh, int):
+            return check_result.fail(
+                "Could not determine the number of heads for query, key and value.",
+            )
+        head_size = hidden_size // (q_nh + (2 * kv_nh))
+        q_hidden_size = head_size * q_nh
+        kv_hidden_size = head_size * kv_nh
         if not (
             _ir_utils.is_singleton_value(start1, 0)
             and _ir_utils.is_singleton_value(end1, q_hidden_size)
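
For intuition, the arithmetic the revised check performs can be worked through with the concrete configuration used by the new unit test below (head size 16, 20 query heads, 10 KV heads). This is an illustrative sketch, not part of the commit:

    # Worked example using the unit test's configuration (see the test diff below)
    q_nh, kv_nh = 20, 10                             # query and key/value head counts
    hidden_size = 640                                # packed QKV width: Q (320) + K (160) + V (160)
    head_size = hidden_size // (q_nh + (2 * kv_nh))  # 640 // 40 == 16
    q_hidden_size = head_size * q_nh                 # 320: end offset of the query slice
    kv_hidden_size = head_size * kv_nh               # 160: width of each of the key and value slices

The is_singleton_value checks that follow then verify that the source model's Slice boundaries (0, 320, 480, 640 in the test) agree with these computed sizes.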
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+import torch
+
+import onnxscript
+import onnxscript.ir as ir
+import onnxscript.ir.passes.common.shape_inference as shape_inference
+import onnxscript.optimizer
+from onnxscript import FLOAT, script
+from onnxscript import opset18 as op
+from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
+from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
+
+msft_op = onnxscript.values.Opset("com.microsoft", 1)
+
+# Test case for fusion of separate query, key and value inputs
+# into a single packed QKV input for the GroupQueryAttention operator.
+
+
+class PackedQKVforGQAFusionTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Config parameters
+        self.batchsize = 1
+        self.seqlen = 8
+        self.kv_seqlen = self.seqlen
+        self.past_seqlen = 16
+        self.head_size = 16
+        self.q_num_heads = 20
+        self.kv_num_heads = 10
+
+        # Computed config parameters
+        self.q_hidden_size = self.head_size * self.q_num_heads
+        self.kv_hidden_size = self.head_size * self.kv_num_heads
+        self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size
+
+        # Abbreviations
+        B = self.batchsize
+        S = self.seqlen
+        P = self.past_seqlen
+        D = self.hidden_size
+        Dh = self.head_size
+        Hkv = self.kv_num_heads
+        total_seqlen = S + P
+        max_seqlen = total_seqlen
+
+        self.input_types = (
+            FLOAT["B", "S", D],  # packed_qkv
+            FLOAT["B", Hkv, "P", Dh],  # past_key
+            FLOAT["B", Hkv, "P", Dh],  # past_value
+            FLOAT["max_seqlen", Dh // 2],  # cos
+            FLOAT["max_seqlen", Dh // 2],  # sin
+        )
+        self.output_types = (
+            FLOAT["B", "S", D],  # attention
+            FLOAT["B", Hkv, "T", Dh],  # present_key
+            FLOAT["B", Hkv, "T", Dh],  # present_value
+        )
+
+        self.inputs = {
+            "packed_qkv": np.random.rand(B, S, D).astype(np.float32),
+            "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+            "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+        }
+
+    def source_model_script(self):
+        Hq = self.q_num_heads
+        Hkv = self.kv_num_heads
+
+        @script()
+        def gqa(packed_qkv, past_key, past_value, cos, sin):
+            # Generate seqlens_k and total_seqlen inputs for GQA:
+            # In this test case, all batch elements have the same sequence length.
+            S = op.Shape(packed_qkv, start=1, end=2)
+            past_seq_length = op.Shape(past_key, start=2, end=3)
+            total_seq_length = op.Add(past_seq_length, S)
+            total_seqlen_int32 = op.Cast(total_seq_length, to=6)
+            total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1)
+            batchsize = op.Shape(packed_qkv, start=0, end=1)
+            seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize)
+
+            # Slice packed_qkv into query, key and value
+            query_BSD = op.Slice(packed_qkv, [0], [320], [2], [1])
+            key_BSDkv = op.Slice(packed_qkv, [320], [480], [2], [1])
+            value_BSDkv = op.Slice(packed_qkv, [480], [640], [2], [1])
+
+            attn, past_key, past_value = msft_op.GroupQueryAttention(
+                query_BSD,
+                key_BSDkv,
+                value_BSDkv,
+                past_key,
+                past_value,
+                seqlens_k,
+                total_seqlen_int32,
+                cos,
+                sin,
+                num_heads=Hq,
+                kv_num_heads=Hkv,
+                do_rotary=1,
+                rotary_interleaved=0,
+            )
+            return attn, past_key, past_value
+
+        return gqa
+
+    def test_fuse_packed_qkv_for_gqa(self):
+        """
+        Test that fusion from query, key and value to a packed QKV for GQA
+        is successful on the source model and produces an equivalent model.
+        """
+        inputs = self.inputs
+
+        source_model = self.source_model_script().to_model_proto(
+            input_types=self.input_types,
+            output_types=self.output_types,
+        )
+        session = ort.InferenceSession(
+            source_model.SerializeToString(), providers=("CPUExecutionProvider",)
+        )
+        source_model_outputs = session.run(None, inputs)
+
+        source_model_ir = ir.serde.from_proto(source_model)
+        inferred_model = shape_inference.infer_shapes(source_model_ir)
+        onnxscript.optimizer.optimize(inferred_model)
+
+        count = fuse_qkv_gqa(inferred_model, debug=True)
+        self.assertEqual(count, 1)
+
+        fused_model = ir.serde.to_proto(inferred_model)
+        session = ort.InferenceSession(
+            fused_model.SerializeToString(), providers=("CPUExecutionProvider",)
+        )
+        fused_model_outputs = session.run(None, inputs)
+
+        self.assertEqual(len(fused_model_outputs), len(source_model_outputs))
+        assert_allclose(fused_model_outputs, source_model_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()
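
As a side note on the test inputs: the gqa script derives seqlens_k and total_seqlen from the input shapes, since every batch element shares the same sequence length. A minimal NumPy sketch of the values this produces for the test configuration (illustrative only, not part of the commit):

    import numpy as np

    # Test configuration: batchsize B = 1, seqlen S = 8, past_seqlen P = 16.
    B, S, P = 1, 8, 16
    total_seqlen_int32 = np.array([S + P], dtype=np.int32)  # [24]
    seqlens_k = np.tile(total_seqlen_int32 - 1, B)          # [23], one entry per batch element

GroupQueryAttention thus sees a total sequence length of 24 (8 new tokens plus 16 cached), which also matches the "T" dimension of present_key and present_value in the declared output types.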
