Add fusion rules (Whisper optimizations) #2221

Merged (27 commits, May 7, 2025)

Commits (all by shubhambhokare1):
35cc855  initial dry run (Apr 16, 2025)
01a5af3  undo sdpa changes (Apr 18, 2025)
7429819  working fusion (Apr 22, 2025)
8ca0a13  new fusions (Apr 23, 2025)
742e4f7  some cleanu (Apr 23, 2025)
be7465d  lint (Apr 23, 2025)
0b1b3bd  rebase modifs (Apr 24, 2025)
87197d1  fix pre_mul_q placement (Apr 25, 2025)
61fc5d1  fix skip_normalization fusion (Apr 25, 2025)
f02d9f6  modify qkv attention axis (Apr 26, 2025)
f186d0a  add one layer encoder (Apr 28, 2025)
ed5d947  add decoder model (Apr 29, 2025)
daa7100  remove comments (Apr 29, 2025)
2805de3  add ort testing for whisper encoder/decoder fusion (Apr 29, 2025)
67621f3  add skip norm tests (Apr 29, 2025)
e518648  add attention tests (Apr 29, 2025)
c986021  fuse mha bias modifs (Apr 29, 2025)
8b3ede5  rebase changes (May 2, 2025)
3749730  fix rebase (May 2, 2025)
bfb3ab6  some comments (May 2, 2025)
080403e  use pattern.value (May 2, 2025)
c9d0483  add attention_bias handling (May 2, 2025)
9634bef  rename sdpa params (May 3, 2025)
5e40467  rewritten mha fusion (May 3, 2025)
b622c80  rewrite cross attention logic (May 5, 2025)
4be6ae1  add slice checks (May 6, 2025)
ba12b06  lint fixes (May 7, 2025)

onnxscript/rewriter/ort_fusions/_core.py: 5 additions & 0 deletions
@@ -16,6 +16,7 @@
from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
from onnxscript.rewriter.ort_fusions.fuse_mha_bias import fuse_mha_bias
from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
@@ -79,6 +80,8 @@
    fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
    fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
    fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
    # We apply shape inference after the SDPA fusion as new nodes are added
    # in the rewrite rule for certain patterns of SDPA.
    fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
    # Optimize to avoid trying multiple attention-based fusions
    fusion_count["mha"] = fuse(fuse_mha)
@@ -87,8 +90,10 @@
        # and avoid trying the attention fusion.
        fusion_count["gqa"] = fuse(fuse_gqa)
        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
        fusion_count["mha_bias"] = 0
        fusion_count["attention"] = 0
    else:
        fusion_count["mha_bias"] = fuse(fuse_mha_bias)
        fusion_count["attention"] = fuse(fuse_attention)
        fusion_count["gqa"] = 0
    fusion_count["gelu"] = fuse(fuse_gelu)

Codecov / codecov/patch: added line onnxscript/rewriter/ort_fusions/_core.py#L93 (fusion_count["mha_bias"] = 0) was not covered by tests.
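
For orientation, here is a minimal sketch of how this fusion pipeline might be driven end to end on an exported Whisper model. It assumes the optimize_for_ort entry point exported by onnxscript.rewriter.ort_fusions (the module this diff edits) returns the rewritten model together with the fusion_count dictionary assembled above; the exact signature may vary across onnxscript versions, and the file paths are hypothetical.

from onnxscript import ir
from onnxscript.rewriter.ort_fusions import optimize_for_ort

# Load an exported Whisper encoder (hypothetical path).
model = ir.load("whisper_encoder.onnx")

# Run the fusion pipeline from _core.py. fusion_count maps each rule
# name ("sdpa", "mha", "mha_bias", "attention", "gqa", ...) to the
# number of times it fired; the return shape is an assumption based on
# the bookkeeping visible in this diff.
model, fusion_count = optimize_for_ort(model)

# Per the branching above: if the MHA fusion applied, expect nonzero
# "mha_bias"/"attention" counts and "gqa" == 0; otherwise the GQA path
# was tried and "mha_bias"/"attention" remain 0.
print(fusion_count)

ir.save(model, "whisper_encoder_fused.onnx")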