Update optimize_for_ort call to allow debug and shape_inference modes #2236

Merged
merged 4 commits on May 1, 2025
12 changes: 10 additions & 2 deletions onnxscript/rewriter/_fusion_utils.py
@@ -5,6 +5,7 @@
 from typing import Callable, Sequence, Union

 import onnxscript.ir as ir
+from onnxscript.ir.passes.common import shape_inference
 from onnxscript.rewriter import pattern

 Dim = Union[int, ir.SymbolicDim]
@@ -26,11 +27,18 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
 def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
     """
     Apply the given fusion rules to the model and return the number of fusions applied.
-    If debug is True, enable pattern matching tracer for debugging.
+
+    model: The input ONNX model represented as an `ir.Model`.
+    debug: If debug is True, enable pattern matching tracer for debugging.
+    apply_shape_inference: If True, apply shape inference after fusions.
     """

-    def apply_to(model: ir.Model, debug: bool = False) -> int:
+    def apply_to(
+        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
+    ) -> int:
         count = rules.apply_to_model(model)
+        if apply_shape_inference:
+            shape_inference.infer_shapes(model)
         if count == 0 and debug:
             tracer = pattern.MatchingTracer()
             rules.apply_to_model(model, tracer=tracer)
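The apply_to closure returned here is what each fuse_* entry point ultimately is, and it now accepts two extra keyword arguments. A minimal usage sketch, not part of the diff (run_fusions and its rules argument are illustrative):

    import onnxscript.ir as ir
    from onnxscript.rewriter import pattern
    from onnxscript.rewriter._fusion_utils import apply_fusion_rules

    def run_fusions(model: ir.Model, rules: pattern.RewriteRuleSet) -> int:
        apply = apply_fusion_rules(rules)
        # New keywords from this PR: re-infer shapes after rewriting, and
        # re-run matching under a tracer when nothing fused and debug is set.
        return apply(model, debug=True, apply_shape_inference=True)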
49 changes: 31 additions & 18 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -47,17 +47,18 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.
-    model = shape_inference.infer_shapes(model)
+    shape_inference.infer_shapes(model)
     optimize(model)
     return model


-def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
+def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
     """
     Apply transformer-specific fusions to the given model.

     Args:
         model: The input ONNX model represented as an `ir.Model`.
+        debug: If debug is True, enable pattern matching tracer for debugging.

     Returns:
         A tuple containing:
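The _pre_optimize change in this hunk stops rebinding the result of infer_shapes: the pass is now treated as updating the model it is given, so the return value can be ignored. A sketch of the calling pattern this relies on (the function name is illustrative):

    import onnxscript.ir as ir
    from onnxscript.ir.passes.common import shape_inference

    def infer_in_place(model: ir.Model) -> ir.Model:
        # Before this PR: model = shape_inference.infer_shapes(model)
        # After: the return value is dropped; `model` itself is assumed to
        # carry the inferred shapes once the pass has run.
        shape_inference.infer_shapes(model)
        return model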
@@ -67,35 +68,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count = dict()

     model = _pre_optimize(model)
-    fusion_count["erf_gelu"] = fuse_erfgelu(model)
-    fusion_count["rms_normalization"] = fuse_rms_normalization(model)
-    fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
-    fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
-    fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
-    fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
-    fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
-    fusion_count["sdpa"] = fuse_sdpa(model)
+
+    def fuse(func, apply_shape_inference: bool = False):
+        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+
+    fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
+    fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
+    fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
+    fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
+    fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
+    fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
+    fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse_mha(model)
+    fusion_count["mha"] = fuse(fuse_mha)
     if fusion_count["mha"] == 0:
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
-        fusion_count["gqa"] = fuse_gqa(model)
-        fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
+        fusion_count["gqa"] = fuse(fuse_gqa)
+        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
         fusion_count["attention"] = 0
     else:
-        fusion_count["attention"] = fuse_attention(model)
+        fusion_count["attention"] = fuse(fuse_attention)
         fusion_count["gqa"] = 0
-    fusion_count["gelu"] = fuse_gelu(model)
-    fusion_count["bias_gelu"] = fuse_bias_gelu(model)
+    fusion_count["gelu"] = fuse(fuse_gelu)
+    fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
     # Finally: inline any intermediate fusion functions introduced that were not
     # consumed by other fusions, and eliminate any remaining unused nodes.
     optimize(model)
     return model, fusion_count


 def optimize_for_ort(
-    model: ir.Model, config_name: str | None = None
+    model: ir.Model,
+    config_name: str | None = None,
+    *,
+    debug: bool = False,
 ) -> tuple[ir.Model, dict[str, int]]:
     """
     Optimize the model for ORT backend.
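Each fuse_* callable passed to the new fuse helper above is the apply_to closure produced by apply_fusion_rules in _fusion_utils.py, so they all share one call shape; a sketch of that contract (the Protocol name is illustrative, not in the PR). Note that only fuse_sdpa passes apply_shape_inference=True, presumably so the MHA/GQA matching that follows sees the shapes produced by the SDPA fusion:

    from typing import Protocol

    import onnxscript.ir as ir

    class FusionFn(Protocol):
        # Contract each fuse_* callable is assumed to satisfy after this PR.
        def __call__(
            self,
            model: ir.Model,
            debug: bool = False,
            apply_shape_inference: bool = False,
        ) -> int:
            """Apply one family of fusion rules; return how many fired."""
            ...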
@@ -108,13 +116,18 @@ def optimize_for_ort(
         config_name: The name of the configuration to use for optimization.
             Typically it identifies the Execution Provider (EP) to optimize for.
            If None, the default configuration will be used.
+        debug: If debug is True, enable pattern matching tracer for debugging.

     Returns:
         A tuple containing:
             - The optimized `ir.Model` after applying transformer-specific fusions.
             - A dictionary with a count of each of the fusions applied.
     """

-    model, fusion_count = fuse_xformers(model)
+    model, fusion_count = fuse_xformers(
+        model,
+        debug=debug,
+    )
     # Apply the ORT pattern rewrite rules.
     rewrite(model, ORT_PATTERN_REWRITE_RULES)
     return model, fusion_count
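Taken together, the debug flag now threads from the public entry point down into every individual fusion. A minimal end-to-end sketch, assuming optimize_for_ort is importable from onnxscript.rewriter.ort_fusions and that onnxscript.ir's load/save helpers are available (the model paths are placeholders):

    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions import optimize_for_ort

    model = ir.load("model.onnx")
    # debug is keyword-only; when a fusion finds no matches, the rules are
    # re-run under a pattern.MatchingTracer to aid debugging.
    model, fusion_count = optimize_for_ort(model, debug=True)
    print(fusion_count)  # per-fusion counts, keyed by names like "sdpa", "mha"
    ir.save(model, "model.opt.onnx")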