Commit 949bc24

Fusion extensions to improve GQA fusion (#2374)
Various extensions to improve GQA fusion:

* Move key-transpose into SDPA fusion and clean it up
* Extend cos-sin-cache fusion to handle a new pattern
* Reorder GQA and MHA rules
* Introduce MaskedGQA, since many uses in practice generated GQA with a mask
* MaskedGQA is subsequently simplified to ORT's GQA if the mask can be verified to be causal

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 321cb41 commit 949bc24
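For context, the reordered fusions can be exercised the same way the updated gqa_test.py below does: run the SDPA fusion first (with shape inference), then the GQA fusion. A minimal sketch, assuming the import paths below match the file layout in this commit (only the location of `fuse_gqa` is shown in this diff; the other paths are assumptions):

```python
# Sketch only: drive the SDPA and GQA fusions in the order used by gqa_test.py.
# Import paths are assumptions inferred from the file paths in this commit.
import onnx
import onnxscript.optimizer
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa  # assumed module path
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa    # defined in gqa.py below

model = ir.serde.deserialize_model(onnx.load("model.onnx"))
onnxscript.optimizer.optimize(model)

sdpa_count = fuse_sdpa(model, debug=True)  # emits intermediate ai.onnxruntime._fusion::SDPA nodes
gqa_count = fuse_gqa(model, debug=True)    # MaskedGQA, then the causal-mask simplification

onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")
```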

9 files changed: +179 −90 lines

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 4 additions & 6 deletions
@@ -86,20 +86,18 @@ def fuse(func, **kwargs):
     # We apply shape inference after the SDPA fusion as new nodes are added
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
-    # Optimize to avoid trying multiple attention-based fusions
+
+    fusion_count["gqa"] = fuse(fuse_gqa)
+    fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
+
     fusion_count["mha1"] = fuse(fuse_mha1)
     fusion_count["mha2"] = fuse(fuse_mha2)
     if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
-        # If no MHA fusion was applied, we can try the GQA fusion.
-        # and avoid trying the attention fusion.
-        fusion_count["gqa"] = fuse(fuse_gqa)
-        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
         fusion_count["mha_bias"] = 0
         fusion_count["attention"] = 0
     else:
         fusion_count["mha_bias"] = fuse(fuse_mha_bias)
         fusion_count["attention"] = fuse(fuse_attention)
-        fusion_count["gqa"] = 0
     fusion_count["gelu"] = fuse(fuse_gelu)
     fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
     # Finally: inline any intermediate fusion functions introduced that were not
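Net effect: the GQA fusions are now attempted unconditionally, before the MHA fusions, rather than only when no MHA fusion fired, so the "gqa" and "packed_qkv_for_gqa" counts are always populated. A hedged sketch of inspecting the counts (the `fuse_xformers` entry point and its return value are assumptions based on this file, not shown in the diff):

```python
# Sketch only: the fuse_xformers name and (model, fusion_count) return shape are
# assumptions based on _core.py; adjust if the actual API differs.
import onnx
from onnxscript import ir
from onnxscript.rewriter.ort_fusions._core import fuse_xformers

model = ir.serde.deserialize_model(onnx.load("model.onnx"))
model, fusion_count = fuse_xformers(model)

# GQA counts are now reported even when an MHA fusion also applied.
print(fusion_count["gqa"], fusion_count["packed_qkv_for_gqa"])
print(fusion_count["mha1"], fusion_count["mha2"])
```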

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 23 additions & 3 deletions
@@ -106,7 +106,16 @@ def cleanup(self):
         self._inv_freq_cos_sin_cache.clear()
 
     def pattern(
-        self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, extra_dims
+        self,
+        op,
+        x,
+        inv_freq,
+        position_ids,
+        interleaved,
+        num_heads,
+        freqs,
+        dtype,
+        extra_dims,
     ):
         if not self._const_freqs:
             # Compute freqs from inv_freq and position_ids. In the _const_freqs case,
@@ -121,6 +130,13 @@ def pattern(
             # if self._reshape:
             #    position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True)
             #    position_ids_expanded = op.Reshape(position_ids_expanded, _allow_other_inputs=True)
+            # inv_freq may optionally be expanded to shape [B, E, 1]
+            inv_freq = pattern.OrValue(
+                [
+                    op.Expand(inv_freq, pattern.ANY_VALUE, _outputs=["expanded_inv_freq"]),
+                    inv_freq,
+                ]
+            )
             freqs = op.MatMul(inv_freq, position_ids_expanded)  # [B, E, S]
             # if self._reshape:
             #    freqs = op.Reshape(freqs, freqs_3d_shape)  # redundant reshape
@@ -140,11 +156,11 @@ def pattern(
             sin_4d,
             interleaved=interleaved,
             num_heads=num_heads,
-            _domain="ai.onnxruntime.fusion",
+            _domain="ai.onnxruntime._fusion",
         )
 
     def check(
-        self, context, inv_freq, position_ids, freqs, extra_dims, **_
+        self, context, inv_freq, position_ids, freqs, extra_dims, expanded_inv_freq=None, **_
     ) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
         # TODO(rama): handle redundant reshape/expand
@@ -164,6 +180,10 @@ def check(
         if not _ir_utils.has_rank(inv_freq, 3):
             return check_result.fail("inv_freq is not 3D.", inv_freq)
         inv_freq_shape = inv_freq.shape
+        if expanded_inv_freq is not None:
+            if not _ir_utils.has_rank(expanded_inv_freq, 3):
+                return check_result.fail("expanded_inv_freq is not 3D.", expanded_inv_freq)
+            # TODO: check expanded_inv_freq shape
         if inv_freq.const_value is None:  # TODO: should this be inv_freq_shape?
             return check_result.fail("inv_freq is not a constant.", inv_freq)
         if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
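The new `pattern.OrValue` alternative covers graphs that insert an explicit `Expand` of `inv_freq` to `[B, E, 1]` before the `MatMul` with the expanded position ids; the Expand is mathematically redundant because MatMul already broadcasts the leading batch dimension, which is why both variants can fuse to the same cos/sin cache. An illustrative numpy sketch of the matched computation (shapes taken from the comments above; this is not the rewriter code itself):

```python
# Illustrative only: B = batch, E = inv_freq length, S = sequence length.
import numpy as np

B, E, S = 2, 64, 8
inv_freq = np.random.rand(1, E, 1).astype(np.float32)          # constant, shape [1, E, 1]
position_ids = np.tile(np.arange(S, dtype=np.float32), (B, 1))  # [B, S]
position_ids_expanded = position_ids[:, None, :]                # [B, 1, S]

# The optional Expand the OrValue alternative now matches: inv_freq -> [B, E, 1].
expanded_inv_freq = np.broadcast_to(inv_freq, (B, E, 1))

freqs = expanded_inv_freq @ position_ids_expanded               # [B, E, S]
# Broadcasting makes the Expand redundant, so both forms describe the same freqs:
assert np.allclose(freqs, inv_freq @ position_ids_expanded)
```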

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 85 additions & 31 deletions
@@ -78,13 +78,11 @@ def pattern(
         value_BSDkv,
         past_key,
         past_value,
-        input_ids,
-        past_seq_length,
-        total_seq_length,
+        position_ids_q,
+        position_ids_k,
         cos,
         sin,
-        some_kv_cache,
-        shape_B111,
+        mask,
     ):
         # Reshape query from (B, S, D) to (B, S, H, D/H)
         query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
@@ -101,10 +99,6 @@ def pattern(
         # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
         value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
 
-        position_ids = op.Range(past_seq_length, total_seq_length, 1)
-        position_ids_q = op.Unsqueeze(position_ids, [0])
-        position_ids_k = op.Unsqueeze(position_ids, [0])
-
         query_BHSDh_rope = op.RotaryEmbedding(
             query_BHSDh,
             position_ids_q,
@@ -141,15 +135,13 @@ def pattern(
             value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
         )
 
-        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
-
-        key_seq_BHDhT = op.Transpose(key_seq_BHTDh, perm=[0, 1, 3, 2])
         attention_BHSDh = op.SDPA(
             query_BHSDh_rope,
-            key_seq_BHDhT,
+            key_seq_BHTDh,
             value_seq_BHTDh,
             mask,
-            _domain="ai.onnxruntime.fusion",
+            key_format="BHSd",
+            _domain="ai.onnxruntime._fusion",
         )
 
         # Transpose attention back to (B, S, H, D/H)
@@ -209,8 +201,8 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         # Rotary embedding attributes
         query_rotary_attributes = query_BHSDh_rope.producer().attributes
         key_rotary_attributes = key_BHkvSDh_rope.producer().attributes
-        query_interleaved = query_rotary_attributes.get("interleaved", 0)
-        key_interleaved = key_rotary_attributes.get("interleaved", 0)
+        query_interleaved = query_rotary_attributes.get_int("interleaved", 0)
+        key_interleaved = key_rotary_attributes.get_int("interleaved", 0)
         if query_interleaved != key_interleaved:
             return pattern.MatchResult().fail(
                 "Rotary embedding interleaved attribute mismatch",
@@ -228,42 +220,104 @@ def rewrite(
         value_BSDkv,
         past_key,
         past_value,
-        total_seq_length,
+        position_ids_q,
+        position_ids_k,
         cos,
         sin,
+        mask,
         **_,
     ):
-        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
-
-        return op.GroupQueryAttention(
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
             query_BSD,
             key_BSDkv,
             value_BSDkv,
             past_key,
             past_value,
-            seqlens_k,
-            total_seq_length_int32,
+            None,  # seqlens_k,
+            None,  # total_seq_length_int32,
             cos,
             sin,
-            # mask, # TODO: this is not a valid input for GQA
             num_heads=self.num_heads,
             kv_num_heads=self.kv_num_heads,
             do_rotary=1,
             rotary_interleaved=self._interleaved,
             # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
-            _domain="com.microsoft",
+            _domain="ai.onnxruntime._fusion",
             _outputs=3,
         )
 
 
-_rule1 = GroupQueryAttention.rule()
+class GQACausalMask(pattern.RewriteRuleClassBase):
+    def __init__(self):
+        super().__init__("GQACausalMask", remove_nodes=False)
+
+    def pattern(
+        self,
+        op,
+        mask,
+        input_ids,
+        some_kv_cache,
+        shape_B111,
+        past_seq_length,
+        total_seq_length,
+    ):
+        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
+        position_ids = op.Range(past_seq_length, total_seq_length, 1)
+        position_ids_q = op.Unsqueeze(position_ids, [0])
+        position_ids_k = op.Unsqueeze(position_ids, [0])
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
+            _allow_other_inputs=True,
+            _domain="ai.onnxruntime._fusion",
+            _outputs=["attn_output", "key_seq", "value_seq"],
+        )
+
+    def rewrite(
+        self,
+        op,
+        total_seq_length,
+        attn_output,
+        **_,
+    ):
+        # Construct total_seq_length_int32 and seqlens_k
+        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
+        one_0D = op.Constant(value_int=1)
+        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
+        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
+        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
+        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
+
+        gqa_node = attn_output.producer()
+        assert len(gqa_node.inputs) == 12, (
+            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
+        )
+        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
+        cos, sin = gqa_node.inputs[10:12]
+        updated_inputs = [
+            query,
+            key,
+            value,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length_int32,
+            cos,
+            sin,
+        ]
+        attributes = gqa_node.attributes
+        return op.GroupQueryAttention(
+            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
+        )
+
 
-gqa_rules = pattern.RewriteRuleSet([_rule1])
+_basic_gqa_rule = GroupQueryAttention.rule()
+_gqa_causal_mask_rule = GQACausalMask.rule()
 
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
 
 fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
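The fusion is now two-stage: the GroupQueryAttention rule emits an intermediate `GQA` node in the `ai.onnxruntime._fusion` domain that still carries the mask and both position-id inputs (the "MaskedGQA" from the commit message), and GQACausalMask replaces it with ORT's `com.microsoft` GroupQueryAttention only when the mask matches `causal_mask_pattern`, synthesizing `seqlens_k` and the int32 total sequence length at that point. A minimal sketch of driving the combined rule set (the module path is assumed from this file, and `RewriteRuleSet.apply_to_model` is assumed to be available as elsewhere in the rewriter):

```python
# Sketch only: apply the two-stage GQA fusion to an IR model.
import onnx
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa, gqa_rules  # path assumed

model = ir.serde.deserialize_model(onnx.load("decoder.onnx"))

# Preferred: the wrapper created at the bottom of gqa.py; returns the rewrite count.
count = fuse_gqa(model, debug=True)

# Equivalent (assumed API): apply the rule set directly; the rules run in order,
# so the basic MaskedGQA fusion runs before the causal-mask simplification.
# count = gqa_rules.apply_to_model(model)

print("GQA fusions applied:", count)
```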

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 8 additions & 2 deletions
@@ -307,6 +307,11 @@ def test_fusion(self):
             onnx.TensorProto.FLOAT,
             ["B", self.seqlen, self.num_heads, self.head_size],
         )
+        key_BHSDh_value_info = onnx.helper.make_tensor_value_info(
+            "key_BHSDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.total_seqlen, self.head_size],
+        )
         key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info(
             "key_BSHkvDh",
             onnx.TensorProto.FLOAT,
@@ -327,6 +332,7 @@ def test_fusion(self):
             query_BHSDh_rope_value_info,
             key_BHkvSDh_rope_value_info,
             query_BSHDh_value_info,
+            key_BHSDh_value_info,
             key_BSHkvDh_value_info,
             key_transposed_value_info,
             value_BHSDh_value_info,
@@ -338,10 +344,10 @@ def test_fusion(self):
         onnxscript.optimizer.optimize(inferred_model)
 
         count = fuse_sdpa(inferred_model, debug=True)
-        self.assertEqual(count, 1)
+        self.assertGreater(count, 0)
 
         count = fuse_gqa(inferred_model, debug=True)
-        self.assertEqual(count, 1)
+        self.assertGreater(count, 0)
 
         fused_model = ir.serde.to_proto(inferred_model)
         session = ort.InferenceSession(

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 6 additions & 31 deletions
@@ -79,25 +79,14 @@ def pattern(
 
         if not self._is_cross_attention:
             # Reshape from (B, S, D) to (B, S, H, D/H)
-            key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
-
-            # Possible Transpose patterns for key:
-            # This scenario optimizes the need for a double transpose
-            # 1. (B, S, H, D/H) -> (B, H, D/H, S)
-            # Patterns with double transpose of key
-            # Double transpose should handle this optimization
-            # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
-            # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
-            # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
-            key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
+            key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
 
             # Reshape from (B, S, D) to (B, S, H, D/H)
             value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
             # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
             value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
         else:
             # For cross-attention, key and value are not reshaped
-            key_BHSDh = key
             value_BHSDh = value
 
         if self._is_rotary:
@@ -118,14 +107,14 @@
             )
             if not self._is_cross_attention:
                 key_BHSDh_emb = op.RotaryEmbedding(
-                    key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
+                    key, position_ids_k, cos, sin, _domain="com.microsoft"
                 )
             else:
-                key_BHSDh_emb = key_BHSDh
+                key_BHSDh_emb = key
         else:
             # If rotary embedding is not used, we fuse with positional_embeddings
             query_BHSDh_emb = query_BHSDh
-            key_BHSDh_emb = key_BHSDh
+            key_BHSDh_emb = key
 
         # Concatenate past_key cache and current key, and transpose to enable
         # dot-product attention computation.
@@ -144,35 +133,21 @@
         key_seq_to_sdpa = key_seq
         value_seq_to_sdpa = value_seq
 
-        # Transpose last two axes of key_seq to compute dot-product via matmul.
-        if self._double_transpose:
-            if self._transpose_4d:
-                key_seq_to_sdpa = op.Transpose(key_seq_to_sdpa, perm=[0, 1, 3, 2])
-            else:
-                # Transpose after converting to 3D
-                key_seq_BH_Skv_Dh = op.Reshape(
-                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
-                )
-                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
-                key_seq_to_sdpa = op.Reshape(
-                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
-                )
-
         # TODO: Remove use_mask once SDPA op is usable
         if self._use_mask:
             sdpa = op.SDPA(
                 query_BHSDh_emb,
                 key_seq_to_sdpa,
                 value_seq_to_sdpa,
                 mask,
-                _domain="ai.onnxruntime.fusion",
+                _domain="ai.onnxruntime._fusion",
             )
         else:
             sdpa = op.SDPA(
                 query_BHSDh_emb,
                 key_seq_to_sdpa,
                 value_seq_to_sdpa,
-                _domain="ai.onnxruntime.fusion",
+                _domain="ai.onnxruntime._fusion",
             )
 
         # Transpose attention back to (B, S, H, D/H)

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def test_whisper_encoder(self):
         original_outputs = ort_run("original", model, inputs)
 
         # Fuse SDPA and MHA
-        sdpa_count = xformers.fuse_sdpa(model)
+        sdpa_count = xformers.fuse_sdpa(model, debug=True)
         self.assertGreater(sdpa_count, 0)
         model = common_passes.ShapeInferencePass()(model).model
         mha_count = xformers.fuse_mha1(model)

onnxscript/rewriter/ort_fusions/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
     def rewrite(self, op, x, cos, sin, **_):
         num_heads = x.shape[1]
         return op.RotaryEmbedding(
-            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
+            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime._fusion"
        )