 from onnxscript.rewriter.ort_fusions.attention import fuse_attention
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
+from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
@@ -70,8 +71,16 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
     fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
     fusion_count["sdpa"] = fuse_sdpa(model)
+    # Optimize to avoid trying multiple attention-based fusions.
     fusion_count["mha"] = fuse_mha(model)
-    fusion_count["attention"] = fuse_attention(model)
+    if fusion_count["mha"] == 0:
+        # If no MHA fusion was applied, we can try the GQA fusion
+        # and avoid trying the attention fusion.
+        fusion_count["gqa"] = fuse_gqa(model)
+        fusion_count["attention"] = 0
+    else:
+        fusion_count["attention"] = fuse_attention(model)
+        fusion_count["gqa"] = 0
     fusion_count["gelu"] = fuse_gelu(model)
     # Finally: inline any intermediate fusion functions introduced that were not
     # consumed by other fusions, and eliminate any remaining unused nodes.