diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 860d6b366e..a72b107eea 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -14,6 +14,7 @@
 )
 from onnxscript.rewriter.ort_fusions.attention import fuse_attention
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
+from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha
@@ -77,6 +78,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse_gqa(model)
+        fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
         fusion_count["attention"] = 0
     else:
         fusion_count["attention"] = fuse_attention(model)
diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
new file mode 100644
index 0000000000..75c4f66f9d
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
@@ -0,0 +1,202 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import Sequence, Union
+
+import onnxscript.ir as ir
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+
+Dim = Union[int, ir.SymbolicDim]
+
+
+class PackedQKVForGQAFusion(pattern.RewriteRuleClassBase):
+    def __init__(self):
+        super().__init__("PackedQKVForGQA", remove_nodes=False)
+
+    def pattern(
+        self,
+        op,
+        packed_qkv,
+        past_key,
+        past_value,
+        seqlens_k,
+        total_seq_length,
+        cos,
+        sin,
+        q_num_heads,
+        kv_num_heads,
+        interleaved,
+        start1,
+        end1,
+        start2,
+        end2,
+        start3,
+        end3,
+    ):
+        """Pattern to detect sliced Q, K, V passed to GQA and replace them with packed QKV."""
+
+        # Slice packed QKV into query, key, and value
+        query_BSD = op.Slice(packed_qkv, start1, end1, [2], [1], _outputs=["query_sliced"])
+        key_BSDkv = op.Slice(packed_qkv, start2, end2, [2], [1], _outputs=["key_sliced"])
+        value_BSDkv = op.Slice(packed_qkv, start3, end3, [2], [1], _outputs=["value_sliced"])
+
+        # Pass sliced Q, K, V to GroupQueryAttention
+        return op.GroupQueryAttention(
+            query_BSD,
+            key_BSDkv,
+            value_BSDkv,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length,
+            cos,
+            sin,
+            # mask, # TODO: this is not a valid input for GQA
+            num_heads=q_num_heads,
+            kv_num_heads=kv_num_heads,
+            do_rotary=1,
+            rotary_interleaved=interleaved,
+            # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
+            _domain="com.microsoft",
+            _outputs=3,
+        )
+
+    def check(
+        self,
+        op,
+        packed_qkv,
+        query_sliced,
+        key_sliced,
+        value_sliced,
+        q_num_heads,
+        kv_num_heads,
+        start1,
+        end1,
+        start2,
+        end2,
+        start3,
+        end3,
+        **_,
+    ):
+        check_result = pattern.MatchResult()
+        self.bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(self.bindings, val, dims)
+
+        # Check that packed_qkv is split into q, k, v correctly
+        # based on hidden sizes
+        if packed_qkv is None or packed_qkv.shape is None or len(packed_qkv.shape) != 3:
+            return check_result.fail("packed_qkv is not a 3D tensor.", packed_qkv)
+        hidden_size = packed_qkv.shape[2]
+        if not isinstance(hidden_size, int):
+            return check_result.fail("Hidden size is not an integer.", packed_qkv)
+        q_nh = q_num_heads.value
+        kv_nh = kv_num_heads.value
+        if not isinstance(q_nh, int) or not isinstance(kv_nh, int):
+            return check_result.fail(
+                "Could not determine the number of heads for query, key and value.",
+            )
+        head_size = hidden_size // (q_nh + (2 * kv_nh))
+        q_hidden_size = head_size * q_nh
+        kv_hidden_size = head_size * kv_nh
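+        # Expected packed layout along the last axis, which the singleton-value
+        # checks below verify slice by slice:
+        #   [0, q_hidden_size)                              -> query
+        #   [q_hidden_size, q_hidden_size + kv_hidden_size) -> key
+        #   [q_hidden_size + kv_hidden_size, hidden_size)   -> value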
+        if not (
+            _ir_utils.is_singleton_value(start1, 0)
+            and _ir_utils.is_singleton_value(end1, q_hidden_size)
+            and _ir_utils.is_singleton_value(start2, q_hidden_size)
+            and _ir_utils.is_singleton_value(end2, (q_hidden_size + kv_hidden_size))
+            and _ir_utils.is_singleton_value(start3, (q_hidden_size + kv_hidden_size))
+            and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size)
+        ):
+            return check_result.fail(
+                "packed_qkv is not being split into q, k, v correctly based on hidden sizes.",
+                packed_qkv,
+            )
+
+        # Check packed_qkv shape (B, S, D)
+        if no_match(packed_qkv, ["B", "S", "D"]):
+            return check_result.fail(
+                f"Shape mismatch: {packed_qkv} does not match expected dimensions ['B', 'S', 'D']",
+                packed_qkv,
+            )
+
+        # Check query (B, S, Dq), key (B, S, Dkv), and value (B, S, Dkv) shapes
+        if no_match(query_sliced, ["B", "S", "Dq"]):
+            return check_result.fail(
+                f"Shape mismatch: {query_sliced} does not match expected dimensions ['B', 'S', 'Dq']",
+                query_sliced,
+            )
+        if no_match(key_sliced, ["B", "S", "Dkv"]):
+            return check_result.fail(
+                f"Shape mismatch: {key_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
+                key_sliced,
+            )
+        if no_match(value_sliced, ["B", "S", "Dkv"]):
+            return check_result.fail(
+                f"Shape mismatch: {value_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
+                value_sliced,
+            )
+
+        # Ensure D = Dq + 2 * Dkv
+        D = self.bindings.get("D")
+        Dq = self.bindings.get("Dq")
+        Dkv = self.bindings.get("Dkv")
+
+        if not isinstance(D, int) or not isinstance(Dq, int) or not isinstance(Dkv, int):
+            return check_result.fail(
+                "Could not determine the hidden sizes of query, key, and value.",
+            )
+
+        if Dq + (2 * Dkv) != D:  # type: ignore[operator]
+            return check_result.fail(
+                f"Hidden sizes of query, key, and value do not add up to the total hidden size: {D} != {Dq} + (2 * {Dkv})",
+            )
+
+        return check_result
+
+    def rewrite(
+        self,
+        op,
+        packed_qkv,
+        past_key,
+        past_value,
+        seqlens_k,
+        total_seq_length,
+        cos,
+        sin,
+        q_num_heads,
+        kv_num_heads,
+        interleaved,
+        **_,
+    ):
+        """Replace sliced Q, K, V with a single packed QKV input for GQA."""
+
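+        # Note: with packed QKV, the key and value inputs are passed as None;
+        # GroupQueryAttention derives the Q/K/V split from its num_heads and
+        # kv_num_heads attributes.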
+        # Pass packed QKV directly to GroupQueryAttention
+        return op.GroupQueryAttention(
+            packed_qkv,
+            None,
+            None,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length,
+            cos,
+            sin,
+            num_heads=q_num_heads,
+            kv_num_heads=kv_num_heads,
+            do_rotary=1,
+            rotary_interleaved=interleaved,
+            _domain="com.microsoft",
+            _outputs=3,
+        )
+
+
+# Define the fusion rule
+packed_qkv_for_gqa_rule = PackedQKVForGQAFusion.rule()
+
+# Add the rule to the GQA rewrite rule set
+fuse_qkv_gqa_rules = pattern.RewriteRuleSet([packed_qkv_for_gqa_rule])
+
+# Apply the fusion rules
+fuse_qkv_gqa = _fusion_utils.apply_fusion_rules(fuse_qkv_gqa_rules)
diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py
new file mode 100644
index 0000000000..9559ca1925
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py
@@ -0,0 +1,141 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import onnxruntime as ort
+
+import onnxscript
+import onnxscript.ir as ir
+import onnxscript.ir.passes.common.shape_inference as shape_inference
+import onnxscript.optimizer
+from onnxscript import FLOAT, INT32, script
+from onnxscript import opset18 as op
+from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose
+from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
+
+msft_op = onnxscript.values.Opset("com.microsoft", 1)
+
+# Test case for fusion of separate query, key, and value inputs
+# into a single packed QKV input for the GroupQueryAttention operator.
+
+
+class PackedQKVForGQAFusionTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Config parameters
+        self.batchsize = 1
+        self.seqlen = 8
+        self.kv_seqlen = self.seqlen
+        self.past_seqlen = 16
+        self.head_size = 16
+        self.q_num_heads = 20
+        self.kv_num_heads = 10
+
+        # Computed config parameters
+        self.q_hidden_size = self.head_size * self.q_num_heads
+        self.kv_hidden_size = self.head_size * self.kv_num_heads
+        self.hidden_size = self.q_hidden_size + self.kv_hidden_size + self.kv_hidden_size
+
+        # Abbreviations
+        B = self.batchsize
+        S = self.seqlen
+        P = self.past_seqlen
+        D = self.hidden_size
+        Dh = self.head_size
+        Hkv = self.kv_num_heads
+        total_seqlen = S + P
+        max_seqlen = total_seqlen
+
+        self.input_types = (
+            FLOAT["B", "S", D],  # packed_qkv
+            FLOAT["B", Hkv, "P", Dh],  # past_key
+            FLOAT["B", Hkv, "P", Dh],  # past_value
+            INT32["B"],  # seqlens_k
+            INT32[1],  # total_sequence_length
+            FLOAT["max_seqlen", Dh // 2],  # cos
+            FLOAT["max_seqlen", Dh // 2],  # sin
+        )
+        self.output_types = (
+            FLOAT["B", "S", D],  # attention
+            FLOAT["B", Hkv, "T", Dh],  # present_key
+            FLOAT["B", Hkv, "T", Dh],  # present_value
+        )
+
+        self.inputs = {
+            "packed_qkv": np.random.rand(B, S, D).astype(np.float32),
+            "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32),
+            "seqlens_k": np.full((B,), total_seqlen - 1, dtype=np.int32),
+            "total_sequence_length": np.array([total_seqlen], dtype=np.int32),
+            "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+            "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32),
+        }
+
+    def source_model_script(self):
+        Hq = self.q_num_heads
+        Hkv = self.kv_num_heads
+
+        @script()
+        def gqa(packed_qkv, past_key, past_value, seqlens_k, total_sequence_length, cos, sin):
+            # Slice packed_qkv into query, key, and value
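+            # Slice boundaries follow from head_size=16, q_num_heads=20, kv_num_heads=10:
+            # q_hidden_size = 16 * 20 = 320, 320 + kv_hidden_size (16 * 10 = 160) = 480,
+            # and the packed hidden size is 480 + 160 = 640.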
+ """ + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_qkv_gqa(inferred_model, debug=True) + self.assertEqual(count, 1) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + fused_model_outputs = session.run(None, inputs) + + self.assertEqual(len(fused_model_outputs), len(source_model_outputs)) + assert_allclose(fused_model_outputs, source_model_outputs) + + +if __name__ == "__main__": + unittest.main()