SDPA fusion cleanup #2352
Signed-off-by: Ganesan Ramalingam <[email protected]>
Codecov Report
Attention: Patch coverage is
✅ All tests successful. No failed tests found.
Additional details and impacted files
Coverage Diff
Metric | main | #2352 | +/- |
---|---|---|---|
Coverage | 70.16% | 70.20% | +0.04% |
Files | 198 | 198 | |
Lines | 24844 | 24871 | +27 |
Branches | 2670 | 2659 | -11 |
Hits | 17431 | 17461 | +30 |
Misses | 6495 | 6496 | +1 |
Partials | 918 | 914 | -4 |
View full report in Codecov by Sentry.
Pull Request Overview
This PR cleans up the SDPA fusion logic by consolidating multiple scaling rules into a single, simplified pattern-disjunction approach and by streamlining shape checks and error message generation.
- Updated custom scale factor computations in SDPA test scripts.
- Simplified SDPA rewrite and check functions as well as the shape checking logic.
- Adjusted GQA test tensor declarations and modularized related fusion utilities.
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
onnxscript/rewriter/ort_fusions/sdpa_test.py | Updated custom scale factor computation for tests. |
onnxscript/rewriter/ort_fusions/sdpa.py | Simplified SDPA rewrite & check logic using disjunction patterns. |
onnxscript/rewriter/ort_fusions/gqa_test.py | Added tensor value_info updates for key/value tensors. |
onnxscript/rewriter/_rewrite_rule.py | Wrapped condition-checking logic with try/except handling. |
onnxscript/rewriter/_fusion_utils.py | Introduced a stricter shape check function with improved error messaging. |
onnxscript/rewriter/_basics.py | Added MatchFailure error classes for pattern matching failures. |
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Pull Request Overview
This PR cleans up the SDPA fusion logic by consolidating many specialized rules into a single pattern with disjunctions, unifying scale-factor handling, and adding explicit shape checks and failure reporting.
- Use `pattern.OrValue` to match mul/div/none cases for query, key, and Q·K scaling instead of multiple separate rules.
- Introduce a helper `get_scale_value` and simplify default-vs-custom scale logic.
- Centralize shape validation via a new `check_shape` in `_fusion_utils`, and enhance match failure handling.
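The scale-factor unification described above can be illustrated with a small sketch. It is not the code from `sdpa.py`: the function `effective_scale` and its `("mul" | "div", factor)` encoding are hypothetical, chosen only to show how the mul/div/none cases on query, key, and Q·K collapse into one number. The only outside assumption is the conventional SDPA default scale of 1/sqrt(head_size).

```python
import math


def effective_scale(head_size, query_scale=None, key_scale=None, qk_scale=None):
    """Fold optional scalings into one multiplier (hypothetical helper).

    Each *_scale is None (no scaling) or a tuple ("mul" | "div", factor).
    Returns None when the combined factor equals the conventional default
    1/sqrt(head_size), otherwise the custom combined factor.
    """

    def as_multiplier(scale):
        if scale is None:
            return 1.0
        op, factor = scale
        if op == "mul":
            return factor
        if op == "div":
            return 1.0 / factor
        raise ValueError(f"unknown scaling op: {op!r}")

    # Scaling query by a and key by b is equivalent to scaling Q.K^T by a*b.
    total = (
        as_multiplier(query_scale)
        * as_multiplier(key_scale)
        * as_multiplier(qk_scale)
    )
    default = 1.0 / math.sqrt(head_size)
    return None if math.isclose(total, default) else total
```

With this encoding, dividing Q·K by 8 for a head size of 64 is recognized as the default scale, while multiplying query and key by 0.5 each yields a custom scale of 0.25.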
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
File | Description |
---|---|
sdpa_test.py | Updated custom scale constants and simplified fusion tests |
sdpa.py | Replaced many rules with a single `SDPA.rule()` using `OrValue`, refactored `check`/`rewrite` |
gqa_test.py | Added total_seqlen and corresponding value_info entries |
onnxscript/rewriter/_rewrite_rule.py | Wrapped condition checks in try/except to capture MatchFailureError |
onnxscript/rewriter/_fusion_utils.py | Added check_shape utility raising MatchFailureError |
onnxscript/rewriter/_basics.py | Introduced MatchFailureInfo and MatchFailureError classes |
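As a rough illustration of the kind of centralized shape check summarized in the table, the sketch below binds symbolic dimension names on first use and raises on any later mismatch. It is an assumption-laden stand-in, not the `check_shape` from `_fusion_utils.py`: the signature, the `bindings` dictionary, and the constructor of the stand-in `MatchFailureError` are all hypothetical, and plain strings take the place of `ir.SymbolicDim`.

```python
from typing import Dict, Sequence, Union


class MatchFailureError(Exception):
    """Stand-in for the PR's exception; constructor and attributes are assumed."""

    def __init__(self, reason, *failure_sources):
        super().__init__(reason)
        self.reason = reason
        self.failure_sources = failure_sources


# Strings stand in for symbolic dimensions in this sketch.
Dim = Union[int, str]


def check_shape(
    bindings: Dict[str, Dim],
    name: str,
    actual: Sequence[Dim],
    expected: Sequence[str],
) -> None:
    """Check `actual` against symbolic names, binding each name on first use."""
    if len(actual) != len(expected):
        raise MatchFailureError(
            f"{name}: expected rank {len(expected)}, got {len(actual)}", name
        )
    for axis, (dim, symbol) in enumerate(zip(actual, expected)):
        # The first occurrence binds the symbol; later ones must agree.
        bound = bindings.setdefault(symbol, dim)
        if bound != dim:
            raise MatchFailureError(
                f"{name}: axis {axis} ({symbol}) is {dim}, expected {bound}", name
            )
```

Sharing one `bindings` dictionary across the query, key, and value checks is what lets a single call report which tensor and which axis broke the expected layout.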
Comments suppressed due to low confidence (2)
onnxscript/rewriter/_fusion_utils.py:13
`Union` is not imported in this module, leading to a `NameError`. Add `from typing import Union` at the top.
Dim = Union[int, ir.SymbolicDim]
onnxscript/rewriter/_rewrite_rule.py:183
`MatchFailureError` defines `failure_sources`, not `failure_nodes_and_values`. This line should use `e.failure_sources` or expose a matching property on the exception.
list(e.failure_nodes_and_values),
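The second suppressed comment concerns the try/except wrapper around condition checks. A minimal sketch of that pattern follows, assuming an exception that exposes `failure_sources` as the comment states. The `run_check` helper, its return shape, and the exception's constructor are hypothetical and are not taken from `_rewrite_rule.py`.

```python
class MatchFailureError(Exception):
    """Stand-in exception; only the `failure_sources` name comes from the review."""

    def __init__(self, reason, *failure_sources):
        super().__init__(reason)
        self.reason = reason
        self.failure_sources = failure_sources


def run_check(check, *args):
    """Run a condition function and report (ok, reason, sources).

    A raised MatchFailureError is converted into a failed result instead of
    propagating, so the caller can log why a rule did not apply.
    """
    try:
        result = check(*args)
    except MatchFailureError as e:
        # Read the attribute the exception actually defines.
        return False, e.reason, list(e.failure_sources)
    return bool(result), "", []


def positive_head_size(head_size):
    # Example condition: raise with context rather than returning False.
    if head_size <= 0:
        raise MatchFailureError("head size must be positive", "query")
    return True
```

Reading an attribute the exception does not define, as the flagged line does, would raise `AttributeError` inside the handler and hide the original failure reason.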
Co-authored-by: Ti-Tai Wang <[email protected]>
LGTM
Remove the need for many different rules for SDPA fusion by (a) using pattern disjunction, and (b) simplifying the handling of scaling factors, which can occur in several forms (using either multiplication or division, applied either separately to query and/or key, or to the product of query and key).
Also: simplify the way shapes are checked and error messages are generated.
---------
Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <[email protected]>