|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import Sequence, Union |
| 6 | + |
| 7 | +import onnxscript.ir as ir |
| 8 | +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern |
| 9 | + |
# A single tensor dimension: either a concrete integer or a symbolic dimension.
Dim = Union[int, ir.SymbolicDim]
| 11 | + |
| 12 | + |
class PackedQKVForGQAFusion(pattern.RewriteRuleClassBase):
    """Feed a packed QKV tensor straight into GroupQueryAttention.

    The pattern matches a ``com.microsoft`` GroupQueryAttention node whose
    query, key and value inputs are three ``Slice`` outputs taken from the
    same tensor along axis 2. The rewrite emits a GroupQueryAttention node
    that takes the unsliced tensor as its first input and leaves the key and
    value inputs empty.
    """

    def __init__(self):
        # NOTE(review): remove_nodes=False presumably leaves the matched nodes
        # in place for a later cleanup pass -- confirm against
        # RewriteRuleClassBase.
        super().__init__("PackedQKVForGQA", remove_nodes=False)

    def pattern(
        self,
        op,
        packed_qkv,
        past_key,
        past_value,
        seqlens_k,
        total_seq_length,
        cos,
        sin,
        q_num_heads,
        kv_num_heads,
        interleaved,
        start1,
        end1,
        start2,
        end2,
        start3,
        end3,
    ):
        """Pattern to detect sliced Q, K, V passed to GQA and replace with packed QKV."""

        # Slice packed QKV into query, key, and value (axis 2, step 1).
        # The output names are how check() receives the sliced values.
        query_BSD = op.Slice(packed_qkv, start1, end1, [2], [1], _outputs=["query_sliced"])
        key_BSDkv = op.Slice(packed_qkv, start2, end2, [2], [1], _outputs=["key_sliced"])
        value_BSDkv = op.Slice(packed_qkv, start3, end3, [2], [1], _outputs=["value_sliced"])

        # Pass sliced Q, K, V to GroupQueryAttention
        return op.GroupQueryAttention(
            query_BSD,
            key_BSDkv,
            value_BSDkv,
            past_key,
            past_value,
            seqlens_k,
            total_seq_length,
            cos,
            sin,
            # mask, # TODO: this is not a valid input for GQA
            num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            do_rotary=1,
            rotary_interleaved=interleaved,
            # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
            _domain="com.microsoft",
            _outputs=3,
        )

    def check(
        self,
        op,
        packed_qkv,
        query_sliced,
        key_sliced,
        value_sliced,
        q_num_heads,
        kv_num_heads,
        start1,
        end1,
        start2,
        end2,
        start3,
        end3,
        **_,
    ):
        """Verify that the matched slices split ``packed_qkv`` into Q, K and V.

        The last dimension of ``packed_qkv`` must be a known integer that is
        an exact multiple of ``q_num_heads + 2 * kv_num_heads``. The three
        slices must then cover ``[0, Dq)``, ``[Dq, Dq + Dkv)`` and
        ``[Dq + Dkv, end)`` where ``end`` is at least the hidden size, and the
        shapes of the sliced values must be consistent with that split.

        Args:
            op: Unused here; part of the rule-check signature.
            packed_qkv: Value that all three slices read from.
            query_sliced: Output of the first Slice.
            key_sliced: Output of the second Slice.
            value_sliced: Output of the third Slice.
            q_num_heads: Matched ``num_heads`` attribute; its ``.value`` is read.
            kv_num_heads: Matched ``kv_num_heads`` attribute; its ``.value`` is read.
            start1, end1, start2, end2, start3, end3: Slice bounds of the
                query, key and value slices respectively.

        Returns:
            ``True`` when every check passes, otherwise the result of
            ``check_result.fail(...)`` describing the first failed check.
        """
        check_result = pattern.MatchResult()
        # Reset per call so dimension bindings from an earlier match attempt
        # cannot leak into this one.
        self.bindings: dict[str, Dim] = {}

        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
            return not _fusion_utils._check_shape(self.bindings, val, dims)

        # Check that packed_qkv is being split into q, k, v correctly
        # based on hidden sizes
        if packed_qkv is None or packed_qkv.shape is None or len(packed_qkv.shape) != 3:
            return check_result.fail("packed_qkv is not a 3D tensor.", packed_qkv)
        hidden_size = packed_qkv.shape[2]
        if not isinstance(hidden_size, int):
            return check_result.fail("Hidden size is not an integer.", packed_qkv)
        q_nh = q_num_heads.value
        kv_nh = kv_num_heads.value
        if not isinstance(q_nh, int) or not isinstance(kv_nh, int):
            return check_result.fail(
                "Could not determine the number of heads for query, key and value.",
            )
        # Reject non-positive head counts up front: they are meaningless for
        # attention and would otherwise allow a zero divisor below.
        if q_nh <= 0 or kv_nh <= 0:
            return check_result.fail(
                "Number of heads for query, key and value must be positive.",
            )
        total_heads = q_nh + (2 * kv_nh)
        # A packed QKV tensor holds q_nh + 2 * kv_nh heads of equal size, so
        # the hidden size must divide evenly; flooring would hide a mismatch.
        if hidden_size % total_heads != 0:
            return check_result.fail(
                "Hidden size is not divisible by the total number of heads.",
                packed_qkv,
            )
        head_size = hidden_size // total_heads
        q_hidden_size = head_size * q_nh
        kv_hidden_size = head_size * kv_nh
        if not (
            _ir_utils.is_singleton_value(start1, 0)
            and _ir_utils.is_singleton_value(end1, q_hidden_size)
            and _ir_utils.is_singleton_value(start2, q_hidden_size)
            and _ir_utils.is_singleton_value(end2, (q_hidden_size + kv_hidden_size))
            and _ir_utils.is_singleton_value(start3, (q_hidden_size + kv_hidden_size))
            # The final slice may use any end bound at or past the hidden size.
            and _ir_utils.is_singleton_value(end3, lambda x: x >= hidden_size)
        ):
            return check_result.fail(
                "packed_qkv is not being split into q, k, v correctly based on hidden sizes.",
                packed_qkv,
            )

        # Check packed_qkv shape (B, S, D)
        if no_match(packed_qkv, ["B", "S", "D"]):
            return check_result.fail(
                f"Shape mismatch: {packed_qkv} does not match expected dimensions ['B', 'S', 'D']",
                packed_qkv,
            )

        # Check query shape (B, S, Dq) and key/value shapes (B, S, Dkv)
        if no_match(query_sliced, ["B", "S", "Dq"]):
            return check_result.fail(
                f"Shape mismatch: {query_sliced} does not match expected dimensions ['B', 'S', 'Dq']",
                query_sliced,
            )
        if no_match(key_sliced, ["B", "S", "Dkv"]):
            return check_result.fail(
                f"Shape mismatch: {key_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
                key_sliced,
            )
        if no_match(value_sliced, ["B", "S", "Dkv"]):
            return check_result.fail(
                f"Shape mismatch: {value_sliced} does not match expected dimensions ['B', 'S', 'Dkv']",
                value_sliced,
            )

        # Ensure D = Dq + 2*Dkv
        D = self.bindings.get("D")
        Dq = self.bindings.get("Dq")
        Dkv = self.bindings.get("Dkv")

        if not isinstance(D, int) or not isinstance(Dq, int) or not isinstance(Dkv, int):
            return check_result.fail(
                "Could not determine the hidden sizes of query, key, and value.",
            )

        if Dq + (2 * Dkv) != D:  # type: ignore[operator]
            return check_result.fail(
                f"Hidden size of query, key and value do not add up to hidden size: {D} != {Dq} + (2 * {Dkv})",
            )

        return True

    def rewrite(
        self,
        op,
        packed_qkv,
        past_key,
        past_value,
        seqlens_k,
        total_seq_length,
        cos,
        sin,
        q_num_heads,
        kv_num_heads,
        interleaved,
        **_,
    ):
        """Build the GroupQueryAttention node that consumes the packed QKV tensor."""

        # Pass packed QKV directly to GroupQueryAttention; the key and value
        # inputs are left empty (None).
        return op.GroupQueryAttention(
            packed_qkv,
            None,
            None,
            past_key,
            past_value,
            seqlens_k,
            total_seq_length,
            cos,
            sin,
            num_heads=q_num_heads,
            kv_num_heads=kv_num_heads,
            do_rotary=1,
            rotary_interleaved=interleaved,
            _domain="com.microsoft",
            _outputs=3,
        )
| 193 | + |
| 194 | + |
# The single rewrite rule produced by the fusion class.
packed_qkv_for_gqa_rule = PackedQKVForGQAFusion.rule()

# Rule set containing only that rule.
fuse_qkv_gqa_rules = pattern.RewriteRuleSet(
    [
        packed_qkv_for_gqa_rule,
    ]
)

# Module entry point built from the rule set.
# NOTE(review): presumably a callable that applies the rules to a model --
# confirm against _fusion_utils.apply_fusion_rules.
fuse_qkv_gqa = _fusion_utils.apply_fusion_rules(fuse_qkv_gqa_rules)
0 commit comments