|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | import onnxscript.ir as ir
|
| 6 | +from onnxscript.ir.passes.common import shape_inference |
6 | 7 | from onnxscript.optimizer import optimize, remove_unused_nodes
|
7 | 8 | from onnxscript.rewriter import rewrite
|
8 | 9 | from onnxscript.rewriter.ort_fusions import (
|
|
12 | 13 | softmax,
|
13 | 14 | )
|
14 | 15 | from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
|
| 16 | +from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu |
15 | 17 | from onnxscript.rewriter.ort_fusions.mha import fuse_mha
|
16 | 18 | from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
|
17 |
| -from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding |
| 19 | +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( |
| 20 | + fuse_partial_rotary_embedding, |
| 21 | + fuse_rotary_embedding, |
| 22 | +) |
18 | 23 | from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
|
19 | 24 | from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
|
20 | 25 |
|
|
27 | 32 | ]
|
28 | 33 |
|
29 | 34 |
|
30 |
| -def fuse_xformers(model: ir.Model) -> None: |
| 35 | +# Preliminary optimizations before applying the transformer fusions. |
| 36 | +# TODO: There are some potential redundancies below. Can be targeted for optimization |
| 37 | +# once we have robust fusion. |
| 38 | +def _pre_optimize(model: ir.Model) -> ir.Model: |
| 39 | + optimize(model) |
| 40 | + # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some |
| 41 | + # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet |
| 42 | + # incorporated in our optimizer. |
| 43 | + model = shape_inference.infer_shapes(model) |
31 | 44 | optimize(model)
|
| 45 | + return model |
| 46 | + |
| 47 | + |
| 48 | +def fuse_xformers(model: ir.Model) -> None: |
| 49 | + model = _pre_optimize(model) |
32 | 50 | fuse_rms_normalization(model)
|
33 | 51 | fuse_normalization(model)
|
34 | 52 | fuse_rotary_embedding(model)
|
| 53 | + fuse_partial_rotary_embedding(model) |
35 | 54 | fuse_cos_sin_cache(model)
|
36 | 55 | fuse_sdpa(model)
|
37 | 56 | fuse_mha(model)
|
| 57 | + fuse_gelu(model) |
38 | 58 | remove_unused_nodes(model)
|
39 | 59 |
|
40 | 60 |
|
|
0 commit comments