Skip to content

Commit a1c9380

Browse files
authored
Cleanup ort transformer fusions (#2115)
Cleanup ort transformer-fusions.
1 parent b6a0a81 commit a1c9380

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import onnxscript.ir as ir
6+
from onnxscript.ir.passes.common import shape_inference
67
from onnxscript.optimizer import optimize, remove_unused_nodes
78
from onnxscript.rewriter import rewrite
89
from onnxscript.rewriter.ort_fusions import (
@@ -12,9 +13,13 @@
1213
softmax,
1314
)
1415
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
16+
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
1517
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
1618
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
17-
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
19+
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
20+
fuse_partial_rotary_embedding,
21+
fuse_rotary_embedding,
22+
)
1823
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
1924
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
2025

@@ -27,14 +32,29 @@
2732
]
2833

2934

30-
def fuse_xformers(model: ir.Model) -> None:
35+
# Preliminary optimizations before applying the transformer fusions.
36+
# TODO: There are some potential redundancies below. Can be targeted for optimization
37+
# once we have robust fusion.
38+
def _pre_optimize(model: ir.Model) -> ir.Model:
39+
optimize(model)
40+
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
41+
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
42+
# incorporated in our optimizer.
43+
model = shape_inference.infer_shapes(model)
3144
optimize(model)
45+
return model
46+
47+
48+
def fuse_xformers(model: ir.Model) -> None:
49+
model = _pre_optimize(model)
3250
fuse_rms_normalization(model)
3351
fuse_normalization(model)
3452
fuse_rotary_embedding(model)
53+
fuse_partial_rotary_embedding(model)
3554
fuse_cos_sin_cache(model)
3655
fuse_sdpa(model)
3756
fuse_mha(model)
57+
fuse_gelu(model)
3858
remove_unused_nodes(model)
3959

4060

0 commit comments

Comments
 (0)