diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 5fa7848626..cc58490f63 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -405,17 +405,13 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: shape = _get_input(node, 1) if input is None or shape is None: return None + input_shape = input.shape - if input_shape is None: - return None - # input_shape_dims = list(input_shape.dims) - # if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims): - # return None shape_value = state.get_shape_value(shape) - if shape_value is None: + + if shape_value is None or input_shape is None: return None - # target_shape_dims = list(shape_value.dims) - # if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. here if _same_shape(input_shape, shape_value): return op.Identity(input) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index f184a2a673..e1a6be338d 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -39,5 +39,8 @@ def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4): np.testing.assert_equal(baseline_output.shape, optimized_output.shape) np.testing.assert_allclose(baseline_output, optimized_output, rtol=rtol, atol=atol) except AssertionError as e: + diff_mask = ~np.isclose(baseline_output, optimized_output, rtol=rtol, atol=atol) + diff = np.where(diff_mask, "X", " ") + print(diff) print(f"Failed for output {i} with rtol={rtol} and atol={atol}\n{e}") raise diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 7de2bfa522..7f761a3744 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -2,148 +2,268 @@ # Licensed under the MIT License. 
from __future__ import annotations -from onnxscript.rewriter import _fusion_utils, pattern +from typing import Sequence, Union +import numpy as np -class GroupQueryAttention(pattern.RewriteRuleClassBase): - def __init__(self, name: str, *, use_2d_matmul: bool): - super().__init__(name, remove_nodes=False) - self._use_2d_matmul = use_2d_matmul - - def _compute_packed_QKV(self, op, input, weight): - if self._use_2d_matmul: - # Convert batched input of shape (B, S, D) to 2D input (B*S, D) - input = op.Reshape(input, _allow_other_inputs=True) - projected = op.MatMul(input, weight) - if self._use_2d_matmul: - # Convert 2D output back to batched output of shape (B, S, D) - projected = op.Reshape(projected, _allow_other_inputs=True) - # Split combined QKV into Q, K, and V - query_3d = op.Slice(projected, _allow_other_inputs=True) - key_3d = op.Slice(projected, _allow_other_inputs=True) - value_3d = op.Slice(projected, _allow_other_inputs=True) - # Reshape from (B, S, D) to (B, S, H, D/H) - query_4d = op.Reshape( - query_3d, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["query_mm_reshaped"], - ) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - query = op.Transpose(query_4d, perm=[0, 2, 1, 3]) - key_4d = op.Reshape( - key_3d, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["key_mm_reshaped"], - ) - key = op.Transpose(key_4d, perm=[0, 2, 1, 3]) - value_4d = op.Reshape( - value_3d, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=["value_mm_reshaped"], - ) - value = op.Transpose(value_4d, perm=[0, 2, 1, 3]) +import onnxscript.ir as ir +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _ir_utils, pattern + +""" +GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different +for query and key/value. + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length (for current query/key/value) + +Hkv: number of heads for key/value +G = number of groups +H: number of heads = G * Hkv + +Dh: head size or embedding dimension per head +D: input embedding dimension (hidden size) = H * Dh +Dkv: key/value hidden size = Hkv * Dh + +T: total sequence length (after concatenation of past and current key/value) +""" + +Dim = Union[int, ir.SymbolicDim] + + +def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): + seq_len = op.Shape(input_ids, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len = op.Shape(past_kv_cache, end=3, start=2) + past_seq_len_0D = op.Squeeze(past_seq_len) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) - return query, key, value + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But using it for pattern-matching against + # generated onnx model. 
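+    # To make the mask construction below concrete (illustration only, not part of
+    # the matched graph): with past_seq_len = 2 and seq_len = 2 (total_seq_len = 4),
+    # current_range is [2, 3] and the final sliced (S, T) mask is
+    #   [[0, 0, 0, min_float32],
+    #    [0, 0, 0,           0]],
+    # i.e. each query position attends to every earlier position and to itself.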
+ total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_float32 = float(np.finfo(np.float32).min) + mask_all_min = op.Expand(min_float32, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + return mask_B1ST + + +class GroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQA", remove_nodes=False) def pattern( self, op, - input, - qkv_weight, - mask, - cos, - sin, + query_BSD, + key_BSDkv, + value_BSDkv, past_key, past_value, - position_ids, + input_ids, + past_seq_length, + total_seq_length, + cos, + sin, + some_kv_cache, + shape_B111, ): - query, key, value = self._compute_packed_QKV(op, input, qkv_weight) + # Reshape query from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape(query_BSD, _allow_other_inputs=True, _outputs=["query_BSHDh"]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) - query_rope = op.RotaryEmbedding(query, position_ids, cos, sin, _domain="com.microsoft") + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) + key_BSHkvDh = op.Reshape(key_BSDkv, _allow_other_inputs=True, _outputs=["key_BSHkvDh"]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) - key_rope = op.RotaryEmbedding(key, position_ids, cos, sin, _domain="com.microsoft") - present_key = op.Concat(past_key, key_rope, axis=-2) - # Transpose last two axes of present_key to compute dot-product via matmul. - present_key = op.Transpose(present_key, perm=[0, 1, 3, 2]) + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) + value_BSHkvDh = op.Reshape( + value_BSDkv, _allow_other_inputs=True, _outputs=["value_BSHkvDh"] + ) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) - present_value = op.Concat(past_value, value, axis=-2) + position_ids = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) - attention = op.SDPA( - query_rope, present_key, present_value, mask, _domain="ai.onnxruntime.fusion" + query_BHSDh_rope = op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + _domain="com.microsoft", + _outputs=["query_BHSDh_rope"], ) - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + key_BHkvSDh_rope = op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + _domain="com.microsoft", + _outputs=["key_BHkvSDh_rope"], + ) + + # Concatenate past_key cache and current key, expand across heads + # that share key/value. 
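+        # Shape walk-through of the expansion below (descriptive comment only):
+        #   Concat:    (B, Hkv, P, Dh) with (B, Hkv, S, Dh) -> (B, Hkv, T, Dh)
+        #   Unsqueeze: (B, Hkv, T, Dh)    -> (B, Hkv, 1, T, Dh)
+        #   Expand:    (B, Hkv, 1, T, Dh) -> (B, Hkv, G, T, Dh)
+        #   Reshape:   (B, Hkv, G, T, Dh) -> (B, H, T, Dh)
+        # so that each group of G query heads reads the same key/value head.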
+ + key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, _allow_other_inputs=True) + key_seq_BHTDh = op.Reshape( + key_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["key_seq_BHTDh"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. + value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, _allow_other_inputs=True) + value_seq_BHTDh = op.Reshape( + value_seq_BHkvGTDh, _allow_other_inputs=True, _outputs=["value_seq_BHTDh"] + ) + + mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + + key_seq_BHDhT = op.Transpose(key_seq_BHTDh, perm=[0, 1, 3, 2]) + attention_BHSDh = op.SDPA( + query_BHSDh_rope, + key_seq_BHDhT, + value_seq_BHTDh, + mask, + _domain="ai.onnxruntime.fusion", + ) + + # Transpose attention back to (B, S, H, D/H) + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) - attention_reshaped = op.Reshape( - attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"] + attention_BSD = op.Reshape( + attention_BSHDh, _allow_other_inputs=True, _outputs=["attention_BSD"] ) - return attention_reshaped, present_key, present_value + return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh def check( self, op, - # query_mm_reshaped, - # key_mm_reshaped, - # value_mm_reshaped, - # key_reshaped, - # key_transposed, - # attention_reshaped, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + query_BHSDh_rope, + key_BHkvSDh_rope, + query_BSHDh, + key_BSHkvDh, **_, - ) -> pattern.MatchResult: # type: ignore[name-defined] - check_result = pattern.MatchResult() - # bindings: dict[str, int] = {} - # status = ( - # _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) - # and _check_shape(bindings, key_reshaped, ["B*H", "KVS", "d_h"]) - # and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "KVS"]) - # and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) - # ) - # if not status: - # return False - # if bindings["B"] * bindings["H"] != bindings["B*H"]: - # return False - # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: - # return False - return check_result + ): + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils._check_shape(bindings, val, dims) + + if no_match(query_BSD, ["B", "S", "D"]): + return False + if no_match(key_BSDkv, ["B", "S", "Dkv"]): + return False + if no_match(value_BSDkv, ["B", "S", "Dkv"]): + return False + + if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + return False + if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + return False + + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + + result = pattern.MatchResult() + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + if not isinstance(num_heads, int): + return result.fail("Unable to determine num_heads value", query_BSHDh) + if not isinstance(kv_num_heads, int): + return result.fail("Unable to 
determine kv_num_heads value", key_BSHkvDh) + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + + # Rotary embedding attributes + query_rotary_attributes = query_BHSDh_rope.producer().attributes + key_rotary_attributes = key_BHkvSDh_rope.producer().attributes + query_interleaved = query_rotary_attributes.get("interleaved", 0) + key_interleaved = key_rotary_attributes.get("interleaved", 0) + if query_interleaved != key_interleaved: + return pattern.MatchResult().fail( + "Rotary embedding interleaved attribute mismatch", + [query_BHSDh_rope.producer(), key_BHkvSDh_rope.producer()], + ) + self._interleaved = query_interleaved + + return True def rewrite( self, op, - input, - qkv_weight, - mask, - cos, - sin, + query_BSD, + key_BSDkv, + value_BSDkv, past_key, past_value, - position_ids, - query_mm_reshaped, + total_seq_length, + cos, + sin, **_, ): - num_heads = query_mm_reshaped.shape[2] - qkv = op.MatMul(input, qkv_weight) + total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) + one_0D = op.Constant(value_int=1) + one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) + seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32) + zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) + seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) + return op.GroupQueryAttention( - qkv, - None, # key - None, # value + query_BSD, + key_BSDkv, + value_BSDkv, past_key, past_value, - # seqlens_k, - # total_sequence_length, + seqlens_k, + total_seq_length_int32, cos, sin, - num_heads=num_heads, + # mask, # TODO: this is not a valid input for GQA + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + do_rotary=1, + rotary_interleaved=self._interleaved, + # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap _domain="com.microsoft", _outputs=3, ) -_rule1 = GroupQueryAttention.rule("MHA_2dmm", use_2d_matmul=False) +_rule1 = GroupQueryAttention.rule() gqa_rules = pattern.RewriteRuleSet([_rule1]) diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py new file mode 100644 index 0000000000..4f8f9ab8ba --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -0,0 +1,344 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math +import unittest + +import numpy as np +import onnx +import onnxruntime as ort +import torch + +import onnxscript +import onnxscript.ir as ir +import onnxscript.ir.passes.common.shape_inference as shape_inference +import onnxscript.optimizer +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose +from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa +from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + +# Test case for GroupQueryAttention (GQA) fusion. + + +class GQAFusionTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? 
+ self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + assert (self.num_heads % self.kv_num_heads) == 0, ( + "num_heads must be divisible by kv_num_heads" + ) + self.num_groups = self.num_heads // self.kv_num_heads + + # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen + + # Input/output types have some dimensions as dynamic (even though the + # test case instance has specific values above). + self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value + ) + + self.inputs = { + "query": np.random.rand(B, S, D).astype(np.float32), + "key": np.random.rand(B, S, Dkv).astype(np.float32), + "value": np.random.rand(B, S, Dkv).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + } + + def target_model_script(self): + H = self.num_heads + Hkv = self.kv_num_heads + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Generate seqlens_k and total_seqlen inputs for GQA: + # In this test case, all batch elements have same sequence length. + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + total_seqlen_int32 = op.Cast(total_seq_length, to=6) + total_seqlen_int32_minus_1 = op.Sub(total_seqlen_int32, 1) + batchsize = op.Shape(query, start=0, end=1) + seqlens_k = op.Tile(total_seqlen_int32_minus_1, batchsize) + + attn, past_key, past_value = msft_op.GroupQueryAttention( + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seqlen_int32, + cos, + sin, + num_heads=H, + kv_num_heads=Hkv, + do_rotary=1, + ) + return attn, past_key, past_value + + return gqa + + def source_model_script(self): + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. 
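+            # For orientation only: with the concrete test configuration above,
+            #   B = 1, S = 8, P = 16, T = 24, H = 20, Hkv = 10, G = 2,
+            #   Dh = 16, D = H * Dh = 320, Dkv = Hkv * Dh = 160.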
+ B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + # past_seq_length = op.Squeeze(past_seq_length_1D, [0]) + # S_0D = op.Squeeze(S,[0]) + + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids_1d, [0]) + position_ids_k = op.Unsqueeze(position_ids_1d, [0]) + + # Note: The above code pattern for position-ids is from exported Phi model. + # However, for use with ORT's RotaryEmbedding it needs the following for batchsize > 1 + # But we currently target batchsize=1 since GQA requires it when there is a past key/value. + # + # position_ids_2d = op.Unsqueeze(position_ids_1d, [0]) + # tile_B_1 = op.Concat(B, plus_1, axis=0) + # position_ids = op.Tile(position_ids_2d, tile_B_1) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + ) + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + seq_len = op.Shape(query, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len_0D = op.Squeeze(past_seq_length) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But duplicating same logic here. 
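+            # For orientation only: with B=1, S=8, P=16 in this test, mask_B1ST_plus
+            # below has shape (1, 1, 8, 25) and the sliced mask_B1ST has shape (1, 1, 8, 24).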
+ total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_val = op.Constant(value=minval_tp) + mask_all_min = op.Expand(min_val, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask_B1ST) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + + return gqa + + def test_equivalence(self): + """Test that the source and target models produce the same outputs.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + target_model = self.target_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + target_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + target_model_outputs = session.run(None, inputs) + + self.assertEqual(len(source_model_outputs), len(target_model_outputs)) + assert_allclose(source_model_outputs, target_model_outputs) + + def test_fusion(self): + """Test that GQA fusion is successful on source model and produces an equivalent model.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + # Some shapes need to be present in input model for fusion to be successful. + # (i) Shape inference doesn't handle handle ORT contrib ops. + # (ii) TODO: investigate if Reshape(..., ["B", "S", -1, Dh]) handled precisely + # by shape inference. 
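+        # The value_info entries below supply those shapes explicitly; in particular,
+        # the GQA fusion's check() reads num_heads and kv_num_heads from dimension 2
+        # of query_BSHDh and key_BSHkvDh.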
+ query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "query_BHSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.seqlen, self.head_size], + ) + key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "key_BHkvSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.kv_num_heads, self.seqlen, self.head_size], + ) + query_BSHDh_value_info = onnx.helper.make_tensor_value_info( + "query_BSHDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.num_heads, self.head_size], + ) + key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info( + "key_BSHkvDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.kv_num_heads, self.head_size], + ) + source_model.graph.value_info.extend( + [ + query_BHSDh_rope_value_info, + key_BHkvSDh_rope_value_info, + query_BSHDh_value_info, + key_BSHkvDh_value_info, + ] + ) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = fuse_sdpa(inferred_model, debug=True) + self.assertEqual(count, 1) + + count = fuse_gqa(inferred_model, debug=True) + self.assertEqual(count, 1) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs3 = session.run(None, inputs) + + self.assertEqual(len(outputs3), len(source_model_outputs)) + assert_allclose(outputs3, source_model_outputs) + + +if __name__ == "__main__": + unittest.main()