diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py
index 166b81d7e2..59bdf87bd0 100644
--- a/onnxscript/rewriter/_fusion_utils.py
+++ b/onnxscript/rewriter/_fusion_utils.py
@@ -5,6 +5,7 @@
 from typing import Callable, Sequence, Union
 
 import onnxscript.ir as ir
+from onnxscript.ir.passes.common import shape_inference
 from onnxscript.rewriter import pattern
 
 Dim = Union[int, ir.SymbolicDim]
@@ -26,11 +27,18 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
 def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
     """
     Apply the given fusion rules to the model and return the number of fusions applied.
-    If debug is True, enable pattern matching tracer for debugging.
+
+    model: The input ONNX model represented as an `ir.Model`.
+    debug: If debug is True, enable pattern matching tracer for debugging.
+    apply_shape_inference: If True, apply shape inference after fusions.
     """
 
-    def apply_to(model: ir.Model, debug: bool = False) -> int:
+    def apply_to(
+        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
+    ) -> int:
         count = rules.apply_to_model(model)
+        if apply_shape_inference:
+            shape_inference.infer_shapes(model)
         if count == 0 and debug:
             tracer = pattern.MatchingTracer()
             rules.apply_to_model(model, tracer=tracer)
diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 52deb6c1b0..6e23700eea 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -47,17 +47,18 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.
-    model = shape_inference.infer_shapes(model)
+    shape_inference.infer_shapes(model)
     optimize(model)
     return model
 
 
-def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
+def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
     """
     Apply transformer-specific fusions to the given model.
 
     Args:
         model: The input ONNX model represented as an `ir.Model`.
+        debug: If debug is True, enable pattern matching tracer for debugging.
 
     Returns:
         A tuple containing:
@@ -67,27 +68,31 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count = dict()
 
     model = _pre_optimize(model)
-    fusion_count["erf_gelu"] = fuse_erfgelu(model)
-    fusion_count["rms_normalization"] = fuse_rms_normalization(model)
-    fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
-    fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
-    fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
-    fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
-    fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
-    fusion_count["sdpa"] = fuse_sdpa(model)
+
+    def fuse(func, apply_shape_inference: bool = False):
+        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+
+    fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
+    fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
+    fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
+    fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
+    fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
+    fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
+    fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
-    fusion_count["mha"] = fuse_mha(model)
+    fusion_count["mha"] = fuse(fuse_mha)
     if fusion_count["mha"] == 0:
         # If no MHA fusion was applied, we can try the GQA fusion.
         # and avoid trying the attention fusion.
-        fusion_count["gqa"] = fuse_gqa(model)
-        fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
+        fusion_count["gqa"] = fuse(fuse_gqa)
+        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
         fusion_count["attention"] = 0
     else:
-        fusion_count["attention"] = fuse_attention(model)
+        fusion_count["attention"] = fuse(fuse_attention)
         fusion_count["gqa"] = 0
-    fusion_count["gelu"] = fuse_gelu(model)
-    fusion_count["bias_gelu"] = fuse_bias_gelu(model)
+    fusion_count["gelu"] = fuse(fuse_gelu)
+    fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
     # Finally: inline any intermediate fusion functions introduced that were not
     # consumed by other fusions, and eliminate any remaining unused nodes.
     optimize(model)
@@ -95,7 +100,10 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
 
 
 def optimize_for_ort(
-    model: ir.Model, config_name: str | None = None
+    model: ir.Model,
+    config_name: str | None = None,
+    *,
+    debug: bool = False,
 ) -> tuple[ir.Model, dict[str, int]]:
     """
     Optimize the model for ORT backend.
@@ -108,6 +116,7 @@ def optimize_for_ort(
         config_name: The name of the configuration to use for optimization.
            Typically it identifies the Execution Provider (EP) to optimize for.
            If None, the default configuration will be used.
+        debug: If debug is True, enable pattern matching tracer for debugging.
 
     Returns:
         A tuple containing:
@@ -115,6 +124,10 @@ def optimize_for_ort(
         - A dictionary with a count of each of the fusions applied.
     """
 
-    model, fusion_count = fuse_xformers(model)
+    model, fusion_count = fuse_xformers(
+        model,
+        debug=debug,
+    )
+    # Apply the ORT pattern rewrite rules.
    rewrite(model, ORT_PATTERN_REWRITE_RULES)
    return model, fusion_count
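
With this change, callers reach the tracer through the new keyword-only debug flag on optimize_for_ort, which forwards it to fuse_xformers and from there to each fusion's apply_to. A minimal usage sketch under assumptions not shown in the diff: the model path is illustrative, and onnxscript.ir.load is assumed to be the usual IR loader.

    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions._core import optimize_for_ort

    # Illustrative path; any ONNX model loaded into the IR works here.
    model = ir.load("model.onnx")

    # With debug=True, fusion rule sets that matched nothing are re-run under a
    # pattern-matching tracer (pattern.MatchingTracer) so match failures can be inspected.
    model, fusion_count = optimize_for_ort(model, debug=True)

    # fusion_count maps fusion names ("erf_gelu", "sdpa", "mha", ...) to the
    # number of times each fusion was applied.
    print(fusion_count)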