Skip to content

Commit 8550064

Browse files
Update optimize_for_ort call to allow debug and shape_inference modes (#2236)
- `debug=True` can be passed for all of the ORT-fusion rules to enable the pattern-matching tracer for debugging. - `apply_shape_inference=True` can be passed if we want to apply shape inference after each fusion rule is applied.
1 parent 349946c commit 8550064

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

onnxscript/rewriter/_fusion_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Callable, Sequence, Union
66

77
import onnxscript.ir as ir
8+
from onnxscript.ir.passes.common import shape_inference
89
from onnxscript.rewriter import pattern
910

1011
Dim = Union[int, ir.SymbolicDim]
@@ -26,11 +27,18 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
2627
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
2728
"""
2829
Apply the given fusion rules to the model and return the number of fusions applied.
29-
If debug is True, enable pattern matching tracer for debugging.
30+
31+
model: The input ONNX model represented as an `ir.Model`.
32+
debug: If debug is True, enable pattern matching tracer for debugging.
33+
apply_shape_inference: If True, apply shape inference after fusions.
3034
"""
3135

32-
def apply_to(model: ir.Model, debug: bool = False) -> int:
36+
def apply_to(
37+
model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
38+
) -> int:
3339
count = rules.apply_to_model(model)
40+
if apply_shape_inference:
41+
shape_inference.infer_shapes(model)
3442
if count == 0 and debug:
3543
tracer = pattern.MatchingTracer()
3644
rules.apply_to_model(model, tracer=tracer)

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,18 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
4747
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
4848
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
4949
# incorporated in our optimizer.
50-
model = shape_inference.infer_shapes(model)
50+
shape_inference.infer_shapes(model)
5151
optimize(model)
5252
return model
5353

5454

55-
def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
55+
def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
5656
"""
5757
Apply transformer-specific fusions to the given model.
5858
5959
Args:
6060
model: The input ONNX model represented as an `ir.Model`.
61+
debug: If debug is True, enable pattern matching tracer for debugging.
6162
6263
Returns:
6364
A tuple containing:
@@ -67,35 +68,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
6768
fusion_count = dict()
6869

6970
model = _pre_optimize(model)
70-
fusion_count["erf_gelu"] = fuse_erfgelu(model)
71-
fusion_count["rms_normalization"] = fuse_rms_normalization(model)
72-
fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
73-
fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
74-
fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
75-
fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
76-
fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
77-
fusion_count["sdpa"] = fuse_sdpa(model)
71+
72+
def fuse(func, apply_shape_inference: bool = False):
73+
return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
74+
75+
fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
76+
fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
77+
fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
78+
fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
79+
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
80+
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
81+
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
82+
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
7883
# Optimize to avoid trying multiple attention-based fusions
79-
fusion_count["mha"] = fuse_mha(model)
84+
fusion_count["mha"] = fuse(fuse_mha)
8085
if fusion_count["mha"] == 0:
8186
# If no MHA fusion was applied, we can try the GQA fusion.
8287
# and avoid trying the attention fusion.
83-
fusion_count["gqa"] = fuse_gqa(model)
84-
fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
88+
fusion_count["gqa"] = fuse(fuse_gqa)
89+
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
8590
fusion_count["attention"] = 0
8691
else:
87-
fusion_count["attention"] = fuse_attention(model)
92+
fusion_count["attention"] = fuse(fuse_attention)
8893
fusion_count["gqa"] = 0
89-
fusion_count["gelu"] = fuse_gelu(model)
90-
fusion_count["bias_gelu"] = fuse_bias_gelu(model)
94+
fusion_count["gelu"] = fuse(fuse_gelu)
95+
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
9196
# Finally: inline any intermediate fusion functions introduced that were not
9297
# consumed by other fusions, and eliminate any remaining unused nodes.
9398
optimize(model)
9499
return model, fusion_count
95100

96101

97102
def optimize_for_ort(
98-
model: ir.Model, config_name: str | None = None
103+
model: ir.Model,
104+
config_name: str | None = None,
105+
*,
106+
debug: bool = False,
99107
) -> tuple[ir.Model, dict[str, int]]:
100108
"""
101109
Optimize the model for ORT backend.
@@ -108,13 +116,18 @@ def optimize_for_ort(
108116
config_name: The name of the configuration to use for optimization.
109117
Typically it identifies the Execution Provider (EP) to optimize for.
110118
If None, the default configuration will be used.
119+
debug: If debug is True, enable pattern matching tracer for debugging.
111120
112121
Returns:
113122
A tuple containing:
114123
- The optimized `ir.Model` after applying transformer-specific fusions.
115124
- A dictionary with a count of each of the fusions applied.
116125
"""
117126

118-
model, fusion_count = fuse_xformers(model)
127+
model, fusion_count = fuse_xformers(
128+
model,
129+
debug=debug,
130+
)
131+
# Apply the ORT pattern rewrite rules.
119132
rewrite(model, ORT_PATTERN_REWRITE_RULES)
120133
return model, fusion_count

0 commit comments

Comments
 (0)