Fusion for partial rotary embedding #2095
PR Overview
This PR adds support for partial rotary embedding by introducing a new fusion rule and a corresponding test case. Key changes include:
- Introducing a partial rotary embedding fusion rule in the rotary_embedding module.
- Adding a new test case for partial rotary embedding in cos_sin_cache_test.
- Updating function signatures (e.g. fuse_cos_sin_cache now requires a debug flag) and extending utility functions (get_singleton_value can now optionally validate tensor rank).
Reviewed Changes
File | Description |
---|---|
onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py | Added a new test method and updated import/test case list to include partial rotary embedding. |
onnxscript/rewriter/ort_fusions/cos_sin_cache.py | Modified fuse_cos_sin_cache signature to require a debug flag and added a new rewrite rule. |
onnxscript/rewriter/ort_fusions/_rotary_embedding_models.py | Added a partial rotary embedding test case and supporting class. |
onnxscript/rewriter/ort_fusions/rotary_embedding.py | Added PartialRotaryEmbeddingFusion rule and updated the fusion functions. |
onnxscript/rewriter/_ir_utils.py | Updated get_singleton_value to optionally validate tensor rank. |
onnxscript/rewriter/llama_rule_sets.py | Updated rule set to include squeeze_reshape_1d_rule. |
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
❌ 6 tests failed; 2 tests were flagged as flaky.
Add a fusion rule for recognizing partial rotary embedding, along with a test case.
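For context, a minimal sketch of what "partial" rotary embedding computes: only the first rotary_dim channels of a head vector are rotated, and the remaining channels pass through unchanged. This is an illustration only. It assumes the common "rotate half" convention and a base of 10000, and the function names (rotate_half, rotary_embedding, partial_rotary_embedding) are hypothetical; none of this is the onnxscript API or the pattern code added in this PR.

```python
import math


def rotate_half(x):
    """Map the halves [a, b] to [-b, a] (the "rotate half" convention)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotary_embedding(x, position, base=10000.0):
    """Apply rotary embedding to every channel of one head vector."""
    dim = len(x)
    half = dim // 2
    # One angle per channel pair, reused for both halves of the vector.
    angles = [position / (base ** (2 * i / dim)) for i in range(half)]
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def partial_rotary_embedding(x, position, rotary_dim, base=10000.0):
    """Rotate only the first rotary_dim channels; pass the rest through."""
    # Assumption for this sketch: rotary_dim is even and <= len(x).
    return rotary_embedding(x[:rotary_dim], position, base) + x[rotary_dim:]


# Usage: a 4-channel vector where only the first 2 channels are rotated.
out = partial_rotary_embedding([1.0, 0.0, 5.0, 7.0], position=1, rotary_dim=2)
```

In an exported graph this typically shows up as a slice/rotate/concat pattern, which is presumably what a fusion rule would collapse into a single rotary embedding op with a reduced rotary dimension.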