Commit e71c889

Move gemm_to_matmul_add rule to ort fusion rules (#2398)
Stop decomposing Gemm into MatMul + Add by default, because Gemm is the more compact representation. Move the gemm_to_matmul_add rule into the ort fusion rules so that it keeps functioning for ort.

Signed-off-by: Justin Chu <[email protected]>
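To illustrate why the two forms are interchangeable in the simple case, here is a minimal pure-Python sketch. It assumes the ONNX definition of Gemm (`alpha * A' @ B' + beta * C`) with a 1-D bias; the helper names (`matmul`, `add_bias`, `gemm`) are made up for this example and are not part of onnxscript. The exact conditions under which the `gemm_to_matmul_add` rule matches are not shown in this commit, so the second check below only demonstrates that the equivalence depends on Gemm's attributes.

```python
def transpose(m):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(m, c):
    """Add a 1-D bias c to every row of m, as a broadcasting Add would."""
    return [[v + c_j for v, c_j in zip(row, c)] for row in m]


def gemm(a, b, c, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    """Reference Gemm: alpha * A' @ B' + beta * C, with C a 1-D bias."""
    a = transpose(a) if trans_a else a
    b = transpose(b) if trans_b else b
    product = matmul(a, b)
    return [[alpha * v + beta * c_j for v, c_j in zip(row, c)] for row in product]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [0.5, -0.5]

# One Gemm node versus the two-node MatMul + Add decomposition.
as_gemm = gemm(A, B, C)
as_matmul_add = add_bias(matmul(A, B), C)
assert as_gemm == as_matmul_add

# With a non-default attribute the plain decomposition no longer matches.
assert gemm(A, B, C, trans_b=True) != as_matmul_add
```

A single Gemm node carries the product, the bias, and the scaling and transpose attributes together, which is the sense in which it is the more compact form.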
1 parent 483599e commit e71c889

3 files changed, 6 additions and 26 deletions

docs/tutorial/optimizer/optimize.md

Lines changed: 3 additions & 20 deletions
@@ -15,6 +15,7 @@ onnxscript.optimizer.optimize(model)
 ```
 
 ### optimize API
+
 The `onnxscript.optimizer.optimize` call takes in several optional parameters that allows the caller to further fine-tune the process of optimization.
 
 ```{eval-rst}
@@ -24,30 +25,12 @@ The `onnxscript.optimizer.optimize` call takes in several optional parameters th
 
 ## Description of optimizations applied by `onnxscript.optimizer.optimize`
 
-:::{table}
-:widths: auto
-:align: center
-
-| Optimization 'onnxscript.optimizer.` + .. | Description |
-| - | - |
+| Optimization | Description |
+|-------------|-------------|
 | **Constant folding** <br>`constant_folding.fold_constants` | Applies constant folding optimization to the model. |
 | **Constant propagation** <br>`constant_folding.fold_constants` | Applies constant propagation optimization to the model. Applied as part of the constant folding optimization. |
 | **Sequence simplification** <br>`constant_folding.fold_constants` | Simplifies Sequence based ops (SequenceConstruct, ConcatFromSequence) present in the model. Applied as part of the constant folding optimization. |
 | **Remove unused nodes** <br>`remove_unused.remove_unused_nodes` | Removes unused nodes from the model. |
 | **Remove unused functions** <br>`remove_unused_function.remove_unused_functions` | Removes unused function protos from the model. |
 | **Inline functions with unused outputs** <br>`simple_function_folding.inline_functions_with_unused_outputs` | Inlines function nodes that have unused outputs. |
 | **Inline simple functions** <br>`simple_function_folding.inline_simple_functions` | Inlines simple functions based on a node count threshold. |
-:::
-
-## List of pattern rewrite rules applied by `onnxscript.optimizer.optimize`
-
-```{eval-rst}
-.. autosummary::
-    :nosignatures:
-
-    onnxscript.rewriter.broadcast_to_matmul
-    onnxscript.rewriter.cast_constant_of_shape
-    onnxscript.rewriter.gemm_to_matmul_add
-    onnxscript.rewriter.no_op
-
-```

onnxscript/rewriter/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -19,7 +19,6 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
-    gemm_to_matmul_add,
     no_op,
     pattern,
 )
@@ -28,7 +27,6 @@
 _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (
     *no_op.rules.rules,  # TODO: merge this rule into constant folding?
     *broadcast_to_matmul.rules.rules,
-    gemm_to_matmul_add.rule,  # type: ignore[has-type]
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
     *basic_rules.basic_optimization_rules().rules,

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 3 additions & 4 deletions
@@ -7,9 +7,8 @@
 import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
 import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization
 from onnxscript.optimizer import optimize
-from onnxscript.rewriter import rewrite
+from onnxscript.rewriter import gemm_to_matmul_add, rewrite
 from onnxscript.rewriter.ort_fusions import (
-    # group_normalization_merge_silu,
     instance_to_group_normalization,
     softmax,
 )
@@ -38,7 +37,7 @@
     *instance_to_group_normalization.rules.rules,
     # NOTE: group normalization merge silu should be applied after instance to group normalization
     # *group_normalization_merge_silu.rules.rules,
-    *fused_matmul_rule_sets.fused_matmul_rule_sets().rules,
+    *fused_matmul_rule_sets.fused_matmul_rule_sets(),
 ]
 
 
@@ -130,7 +129,7 @@ def optimize_for_ort(
         - The optimized `ir.Model` after applying transformer-specific fusions.
         - A dictionary with a count of each of the fusions applied.
     """
-
+    rewrite(model, [gemm_to_matmul_add.rule])
     model, fusion_count = fuse_xformers(
         model,
         debug=debug,
