Commit 605e06e

Add fusion rules (Whisper optimizations) (#2221)
Add fusion rules to support the optimization of Whisper models. Fusions added:

- Basic Fusions:
  * additional pattern for erfgelu [moved to #2222]
- SkipLayerNorm:
  * #2259
  * fusion patterns where skip_sum is also an output
  * Bias + SkipLayerNorm -> SkipLayerNorm (with bias) [moved to #2222]
- BiasGelu Fusion [moved to #2222]
- SDPA:
  * support for the pattern where only q is pre-scaled
- MHA:
  * patterns with/without past/present keys/values
  * patterns with non-rotary embeddings
  * patterns with/without mask
  * patterns with cross-attention (only for past key/value patterns)
- MHA Bias Fusion:
  * the bias was previously offloaded to the Attention fusion; this fusion fixes that
- Attention:
  * patterns where Q, K, and V do not come from slicing

TODO:
- [x] Fix SDPA singular prescale case, due to lost shape information
- [x] Enable check conditions when #2210 is merged
- [x] Improve/rewrite the Whisper model test case to be similar to that of smollm (for example)
- [x] Fix failing test cases to account for new patterns
- [x] Add isolated test cases for new fusions such as BiasGelu, SkipLayerNorm, etc.
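Below is a minimal usage sketch (not part of this commit) of how the fused rules are typically applied to an exported Whisper ONNX model. The optimize_for_ort entry point, its (model, fusion_counts) return value, and the file paths are assumptions based on the module this commit touches and may differ from the actual API.

# Hedged sketch: apply the ORT fusion pipeline, which includes the SDPA / MHA /
# MHA-bias / Attention fusions added in this commit, to an exported Whisper
# decoder. The optimize_for_ort entry point and its (model, fusion_counts)
# return value are assumed; the .onnx paths are illustrative only.
from onnxscript import ir
from onnxscript.rewriter.ort_fusions import optimize_for_ort

model = ir.load("whisper_decoder.onnx")         # illustrative input path
model, fusion_counts = optimize_for_ort(model)  # runs the fusion passes listed above
print({name: n for name, n in fusion_counts.items() if n})  # show which fusions fired
ir.save(model, "whisper_decoder_fused.onnx")    # illustrative output path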
1 parent 33f31ca commit 605e06e

11 files changed: +1357 −283 lines

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@
 from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
 from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
+from onnxscript.rewriter.ort_fusions.fuse_mha_bias import fuse_mha_bias
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
@@ -79,6 +80,8 @@ def fuse(func, apply_shape_inference: bool = False):
     fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
     fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
     fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    # We apply shape inference after the SDPA fusion as new nodes are added
+    # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
     # Optimize to avoid trying multiple attention-based fusions
     fusion_count["mha"] = fuse(fuse_mha)
@@ -87,8 +90,10 @@ def fuse(func, apply_shape_inference: bool = False):
         # and avoid trying the attention fusion.
         fusion_count["gqa"] = fuse(fuse_gqa)
         fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
+        fusion_count["mha_bias"] = 0
         fusion_count["attention"] = 0
     else:
+        fusion_count["mha_bias"] = fuse(fuse_mha_bias)
         fusion_count["attention"] = fuse(fuse_attention)
         fusion_count["gqa"] = 0
     fusion_count["gelu"] = fuse(fuse_gelu)
