Skip to content

Commit 4340a6c

Browse files
Optimization to avoid trying multiple attention-based fusions (#2168)
1 parent f93eb58 commit 4340a6c

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from onnxscript.rewriter.ort_fusions.attention import fuse_attention
1616
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
1717
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
18+
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
1819
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
1920
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
2021
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
@@ -70,8 +71,16 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
7071
fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
7172
fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
7273
fusion_count["sdpa"] = fuse_sdpa(model)
74+
# Optimize to avoid trying multiple attention-based fusions
7375
fusion_count["mha"] = fuse_mha(model)
74-
fusion_count["attention"] = fuse_attention(model)
76+
if fusion_count["mha"] == 0:
77+
# If no MHA fusion was applied, we can try the GQA fusion,
78+
# and avoid trying the attention fusion.
79+
fusion_count["gqa"] = fuse_gqa(model)
80+
fusion_count["attention"] = 0
81+
else:
82+
fusion_count["attention"] = fuse_attention(model)
83+
fusion_count["gqa"] = 0
7584
fusion_count["gelu"] = fuse_gelu(model)
7685
# Finally: inline any intermediate fusion functions introduced that were not
7786
# consumed by other fusions, and eliminate any remaining unused nodes.

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_fuse_xformers(self):
2929
self.assertEqual(fusion_count["sdpa"], 1)
3030
self.assertEqual(fusion_count["mha"], 0)
3131
self.assertEqual(fusion_count["attention"], 0)
32+
self.assertEqual(fusion_count["gqa"], 0)
3233
self.assertEqual(fusion_count["gelu"], 0)
3334

3435
new_outputs = ort_run("optimized", model, inputs)

0 commit comments

Comments (0)