
Commit a3ce145

Ensure rule ordering in MHA fusion (#2334)
MHA fusion rules for patterns without past (key/value cache) should be tried after the MHA fusion rules for patterns with past, so that the more optimal with-past fusions take precedence.

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent a4b91e2 commit a3ce145
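Why ordering matters: in a rewrite-rule set where the first matching rule claims a subgraph, an attention subgraph that carries a key/value cache can also match the more general no-past pattern, which would then block the richer with-past fusion. Below is a minimal, self-contained sketch of this first-match-wins behavior; the Rule class, predicates, and node representation are illustrative stand-ins, not onnxscript's actual rewriter.

class Rule:
    """Toy rewrite rule: a name plus a match predicate (illustrative only)."""

    def __init__(self, name, predicate):
        self.name = name
        self.predicate = predicate


def apply_first_match(rules, node):
    # Earlier rules take precedence: the first rule that matches claims the
    # node, so later (possibly more optimal) rules never get to run on it.
    for rule in rules:
        if rule.predicate(node):
            return rule.name
    return None


# Schematically, a subgraph with a KV cache matches both the specific
# with-past pattern and the more general no-past pattern.
node = {"op": "attention", "has_past": True}
mha_with_past = Rule("MHA_Past", lambda n: n["op"] == "attention" and n["has_past"])
mha_no_past = Rule("MHA", lambda n: n["op"] == "attention")

assert apply_first_match([mha_with_past, mha_no_past], node) == "MHA_Past"  # preferred
assert apply_first_match([mha_no_past, mha_with_past], node) == "MHA"  # suboptimal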

5 files changed: +57 −47 lines changed


onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 4 additions & 3 deletions
@@ -20,7 +20,7 @@
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
-from onnxscript.rewriter.ort_fusions.mha import fuse_mha
+from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -87,8 +87,9 @@ def fuse(func, apply_shape_inference: bool = False):
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse(fuse_mha)
-    if fusion_count["mha"] == 0:
+    fusion_count["mha1"] = fuse(fuse_mha1)
+    fusion_count["mha2"] = fuse(fuse_mha2)
+    if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse(fuse_gqa)

onnxscript/rewriter/ort_fusions/attention_test.py

Lines changed: 2 additions & 1 deletion
@@ -173,7 +173,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def test_fuse_xformers(self):
         self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
         self.assertEqual(fusion_count["cos_sin_cache"], 2)
         self.assertEqual(fusion_count["sdpa"], 1)
-        self.assertEqual(fusion_count["mha"], 1)
+        self.assertEqual(fusion_count["mha1"] + fusion_count["mha2"], 1)
         self.assertEqual(fusion_count["attention"], 0)
         self.assertEqual(fusion_count["gqa"], 0)
         self.assertEqual(fusion_count["gelu"], 0)

onnxscript/rewriter/ort_fusions/mha.py

Lines changed: 44 additions & 39 deletions
@@ -376,45 +376,50 @@ def rewrite(
     )
 
 
-parameter_combinations = [
-    {
-        "double_transpose": double_transpose,
-        "transpose_4d": transpose_4d,
-        "pre_scale_q": pre_scale_q,
-        "is_rotary": is_rotary,
-        "use_mask": use_mask,
-        "has_past_present": has_past_present,
-        "is_cross_attention": is_cross_attention,
-    }
-    for double_transpose in [False, True]
-    for transpose_4d in (
-        [False, True] if double_transpose else [False]
-    )  # Only generate patterns when double_transpose is True
-    for pre_scale_q in [True, False]
-    for is_rotary in [False, True]
-    for use_mask in [False, True]
-    for is_cross_attention in [False, True]
-    for has_past_present in ([False] if is_cross_attention else [True, False])
-    # Skip if both has_past_present and is_cross_attention are True
-    if not (has_past_present and is_cross_attention)
-]
-
-# Dynamically create the rules
-mha_rules = pattern.RewriteRuleSet(
-    [
-        MultiHeadAttention.rule(
-            f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
-            f"{'_Twice' if params['double_transpose'] else ''}"
-            f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
-            f"{'_Rotary' if params['is_rotary'] else ''}"
-            f"{'_Masked' if params['use_mask'] else ''}"
-            f"{'_Past' if params['has_past_present'] else ''}"
-            f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
-            **params,
-        )
-        for params in parameter_combinations
+def _make_rule_set(has_past_present: bool):
+    parameter_combinations = [
+        {
+            "double_transpose": double_transpose,
+            "transpose_4d": transpose_4d,
+            "pre_scale_q": pre_scale_q,
+            "is_rotary": is_rotary,
+            "use_mask": use_mask,
+            "has_past_present": has_past_present,
+            "is_cross_attention": is_cross_attention,
+        }
+        for double_transpose in [False, True]
+        for transpose_4d in (
+            [False, True] if double_transpose else [False]
+        )  # Only generate patterns when double_transpose is True
+        for pre_scale_q in [True, False]
+        for is_rotary in [False, True]
+        for use_mask in [False, True]
+        for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
-)
 
+    # Dynamically create the rules
+    mha_rules = pattern.RewriteRuleSet(
+        [
+            MultiHeadAttention.rule(
+                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"{'_Twice' if params['double_transpose'] else ''}"
+                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
+                f"{'_Rotary' if params['is_rotary'] else ''}"
+                f"{'_Masked' if params['use_mask'] else ''}"
+                f"{'_Past' if params['has_past_present'] else ''}"
+                f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
+                **params,
+            )
+            for params in parameter_combinations
+        ]
+    )
+
+    return mha_rules
+
+
+mha_rules_no_past = _make_rule_set(has_past_present=False)
+mha_rules_with_past = _make_rule_set(has_past_present=True)
 
-fuse_mha = _fusion_utils.apply_fusion_rules(mha_rules)
+# Try rules with past first, and then rules without past.
+fuse_mha1 = _fusion_utils.apply_fusion_rules(mha_rules_with_past)
+fuse_mha2 = _fusion_utils.apply_fusion_rules(mha_rules_no_past)
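With this split, callers run the two passes in sequence, with-past rules first. A minimal usage sketch mirroring the updated tests; the `model` variable is assumed to be an already-loaded, SDPA-fused model in whatever IR form these fusion functions expect.

from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2  # as imported in _core.py

# `model` is assumed to be prepared as in the tests (loaded, SDPA-fused).
mha_count = fuse_mha1(model)   # try the with-past (KV cache) rules first
mha_count += fuse_mha2(model)  # then fall back to the no-past rules
# Each pass returns the number of fusions applied, so mha_count is the total.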

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 6 additions & 3 deletions
@@ -35,7 +35,8 @@ def test_smollm(self):
         # Fuse SDPA and MHA
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
 
         if test_with_ort:
@@ -59,7 +60,8 @@ def test_whisper_encoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         onnxscript.optimizer.optimize(model)
 
@@ -84,7 +86,8 @@ def test_whisper_decoder(self):
         sdpa_count = xformers.fuse_sdpa(model)
         self.assertGreater(sdpa_count, 0)
         model = shape_inference.infer_shapes(model)
-        mha_count = xformers.fuse_mha(model)
+        mha_count = xformers.fuse_mha1(model)
+        mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
         onnxscript.optimizer.optimize(model)
