|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import unittest |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import onnx |
| 9 | +import onnxruntime as ort |
| 10 | +import torch |
| 11 | + |
| 12 | +import onnxscript |
| 13 | +import onnxscript.ir as ir |
| 14 | +import onnxscript.ir.passes.common.shape_inference as shape_inference |
| 15 | +import onnxscript.optimizer |
| 16 | +from onnxscript import FLOAT, script |
| 17 | +from onnxscript import opset18 as op |
| 18 | +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose |
| 19 | +from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa |
| 20 | + |
| 21 | +msft_op = onnxscript.values.Opset("com.microsoft", 1) |
| 22 | + |
| 23 | +# Test case for fusion of separate query, key and value inputs |
| 24 | +# into a single packed QKV input for the GroupQueryAttention operator. |
| 25 | + |
| 26 | + |
| 27 | +class PackedQKVforGQAFusionTest(unittest.TestCase): |
| 28 | + def __init__(self, *args, **kwargs): |
| 29 | + super().__init__(*args, **kwargs) |
| 30 | + # Config parameters |
| 31 | + self.batchsize = 1 |
| 32 | + self.seqlen = 8 |
| 33 | + self.kv_seqlen = self.seqlen |
| 34 | + self.past_seqlen = 16 |
| 35 | + self.head_size = 16 |
| 36 | + self.q_num_heads = 20 |
| 37 | + self.kv_num_heads = 10 |
| 38 | + |
| 39 | + # Computed config parameters |
| 40 | + self.q_hidden_size = self.head_size * self.q_num_heads |
| 41 | + self.kv_hidden_size = self.head_size * self.kv_num_heads |
| 42 | + self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size |
| 43 | + |
| 44 | + # Abbreviations |
| 45 | + B = self.batchsize |
| 46 | + S = self.seqlen |
| 47 | + P = self.past_seqlen |
| 48 | + D = self.hidden_size |
| 49 | + Dh = self.head_size |
| 50 | + Hkv = self.kv_num_heads |
| 51 | + total_seqlen = S + P |
| 52 | + max_seqlen = total_seqlen |
| 53 | + |
| 54 | + self.input_types = ( |
| 55 | + FLOAT["B", "S", D], # packed_qkv |
| 56 | + FLOAT["B", Hkv, "P", Dh], # past_key |
| 57 | + FLOAT["B", Hkv, "P", Dh], # past_value |
| 58 | + FLOAT["max_seqlen", Dh // 2], # cos |
| 59 | + FLOAT["max_seqlen", Dh // 2], # sin |
| 60 | + ) |
| 61 | + self.output_types = ( |
| 62 | + FLOAT["B", "S", D], # attention |
| 63 | + FLOAT["B", Hkv, "T", Dh], # present_key |
| 64 | + FLOAT["B", Hkv, "T", Dh], # present_value |
| 65 | + ) |
| 66 | + |
| 67 | + self.inputs = { |
| 68 | + "packed_qkv": np.random.rand(B, S, D).astype(np.float32), |
| 69 | + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), |
| 70 | + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), |
| 71 | + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), |
| 72 | + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), |
| 73 | + } |
| 74 | + |
| 75 | + def source_model_script(self): |
| 76 | + Hq = self.q_num_heads |
| 77 | + Hkv = self.kv_num_heads |
| 78 | + |
| 79 | + @script() |
| 80 | + def gqa(packed_qkv, past_key, past_value, cos, sin): |
| 81 | + # Generate seqlens_k and total_seqlen inputs for GQA: |
| 82 | + # In this test case, all batch elements have same sequence length. |
| 83 | + S = op.Shape(packed_qkv, start=1, end=2) |
| 84 | + past_seq_length = op.Shape(past_key, start=2, end=3) |
| 85 | + total_seq_length = op.Add(past_seq_length, S) |
| 86 | + total_seqlen_int32 = op.Cast(total_seq_length, to=6) |
| 87 | + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) |
| 88 | + batchsize = op.Shape(packed_qkv, start=0, end=1) |
| 89 | + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) |
| 90 | + |
| 91 | + # Slice packed_qkv into query, key and value |
| 92 | + query_BSD = op.Slice(packed_qkv, [0], [320], [2], [1]) |
| 93 | + key_BSDkv = op.Slice(packed_qkv, [320], [480], [2], [1]) |
| 94 | + value_BSDkv = op.Slice(packed_qkv, [480], [640], [2], [1]) |
| 95 | + |
| 96 | + attn, past_key, past_value = msft_op.GroupQueryAttention( |
| 97 | + query_BSD, |
| 98 | + key_BSDkv, |
| 99 | + value_BSDkv, |
| 100 | + past_key, |
| 101 | + past_value, |
| 102 | + seqlens_k, |
| 103 | + total_seqlen_int32, |
| 104 | + cos, |
| 105 | + sin, |
| 106 | + num_heads=Hq, |
| 107 | + kv_num_heads=Hkv, |
| 108 | + do_rotary=1, |
| 109 | + rotary_interleaved=0, |
| 110 | + ) |
| 111 | + return attn, past_key, past_value |
| 112 | + |
| 113 | + return gqa |
| 114 | + |
| 115 | + def test_fuse_packed_qkv_for_gqa(self): |
| 116 | + """ |
| 117 | + Test that fusion from query, key and value to a packed QKV for GQA |
| 118 | + is successful on source model and produces an equivalent model. |
| 119 | + """ |
| 120 | + inputs = self.inputs |
| 121 | + |
| 122 | + source_model = self.source_model_script().to_model_proto( |
| 123 | + input_types=self.input_types, |
| 124 | + output_types=self.output_types, |
| 125 | + ) |
| 126 | + session = ort.InferenceSession( |
| 127 | + source_model.SerializeToString(), providers=("CPUExecutionProvider",) |
| 128 | + ) |
| 129 | + source_model_outputs = session.run(None, inputs) |
| 130 | + |
| 131 | + source_model_ir = ir.serde.from_proto(source_model) |
| 132 | + inferred_model = shape_inference.infer_shapes(source_model_ir) |
| 133 | + onnxscript.optimizer.optimize(inferred_model) |
| 134 | + |
| 135 | + count = fuse_qkv_gqa(inferred_model, debug=True) |
| 136 | + self.assertEqual(count, 1) |
| 137 | + |
| 138 | + fused_model = ir.serde.to_proto(inferred_model) |
| 139 | + session = ort.InferenceSession( |
| 140 | + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) |
| 141 | + ) |
| 142 | + fused_model_outputs = session.run(None, inputs) |
| 143 | + |
| 144 | + self.assertEqual(len(fused_model_outputs), len(source_model_outputs)) |
| 145 | + assert_allclose(fused_model_outputs, source_model_outputs) |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == "__main__": |
| 149 | + unittest.main() |
0 commit comments