SDPA fusion cleanup #2352

Merged
merged 23 commits into from
Jun 3, 2025

gramalingam
Collaborator

@gramalingam gramalingam commented May 28, 2025

Remove the need for many different rules for SDPA fusion by (a) using pattern disjunction, and (b) simplifying the handling of scaling factors, which can occur in several forms: applied by multiplication or by division, and applied either separately to the query and/or key or to the product of query and key.

Also: simplify the way shapes are checked and error messages are generated.
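
The scaling-factor forms listed in the description all compute the same scaled Q·K^T product, which is why they can be reduced to one effective scale. The following is a small numeric sketch in pure Python, not code from this PR; the helper names `matmul_t` and `scaled` and the sample values are assumptions made for illustration.

```python
import math


def matmul_t(q, k):
    """Return q @ k^T for matrices given as lists of row vectors."""
    return [[sum(a * b for a, b in zip(q_row, k_row)) for k_row in k] for q_row in q]


def scaled(m, factor):
    """Multiply every entry of m by factor."""
    return [[x * factor for x in row] for row in m]


def close(a, b):
    """Element-wise approximate equality for two matrices of equal shape."""
    return all(
        math.isclose(x, y, rel_tol=1e-9, abs_tol=1e-12)
        for row_a, row_b in zip(a, b)
        for x, y in zip(row_a, row_b)
    )


q = [[1.0, 2.0], [3.0, 4.0]]
k = [[0.5, -1.0], [2.0, 1.5]]
s = 1.0 / math.sqrt(len(q[0]))  # illustrative scale: 1 / sqrt(head size)

# Form 1: multiply the product Q.K^T by s.
on_product_mul = scaled(matmul_t(q, k), s)

# Form 2: divide the product by 1/s.
divisor = 1.0 / s
on_product_div = [[x / divisor for x in row] for row in matmul_t(q, k)]

# Form 3: scale query and key separately, each by sqrt(s).
root = math.sqrt(s)
on_query_and_key = matmul_t(scaled(q, root), scaled(k, root))

# Form 4: scale only the query by s.
on_query_only = matmul_t(scaled(q, s), k)
```

Because each form is a product of constants applied around a bilinear operation, a fusion only needs the product of the individual factors, with any divisor contributing its reciprocal.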

Signed-off-by: Ganesan Ramalingam <[email protected]>

codecov bot commented May 28, 2025

Codecov Report

Attention: Patch coverage is 77.35849% with 24 lines in your changes missing coverage. Please review.

Project coverage is 70.20%. Comparing base (4e526f7) to head (06d3e9a).
Report is 1 commit behind head on main.

✅ All tests successful. No failed tests found.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| onnxscript/rewriter/ort_fusions/sdpa_test.py | 60.52% | 14 Missing and 1 partial ⚠️ |
| onnxscript/rewriter/_fusion_utils.py | 63.63% | 2 Missing and 2 partials ⚠️ |
| onnxscript/rewriter/ort_fusions/sdpa.py | 89.47% | 2 Missing and 2 partials ⚠️ |
| onnxscript/rewriter/_basics.py | 90.90% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2352      +/-   ##
==========================================
+ Coverage   70.16%   70.20%   +0.04%     
==========================================
  Files         198      198              
  Lines       24844    24871      +27     
  Branches     2670     2659      -11     
==========================================
+ Hits        17431    17461      +30     
- Misses       6495     6496       +1     
+ Partials      918      914       -4     


@gramalingam gramalingam changed the title SDPA fusion cleanup [DRAFT] SDPA fusion cleanup May 28, 2025
@gramalingam gramalingam marked this pull request as draft May 28, 2025 23:13
@gramalingam gramalingam changed the title [DRAFT] SDPA fusion cleanup SDPA fusion cleanup May 29, 2025
@gramalingam gramalingam marked this pull request as ready for review May 29, 2025 21:52
@justinchuby justinchuby requested a review from Copilot May 30, 2025 00:25
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR cleans up the SDPA fusion logic by consolidating multiple scaling rules into a single, simplified pattern-disjunction approach and by streamlining shape checks and error message generation.

  • Updated custom scale factor computations in SDPA test scripts.
  • Simplified SDPA rewrite and check functions as well as the shape checking logic.
  • Adjusted GQA test tensor declarations and modularized related fusion utilities.
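
The pattern-disjunction idea in the overview above can be illustrated on a toy expression tree: one matcher accepts "x * c", "x / c", or plain "x" and reports a single effective multiplier, instead of one rule per form. This is a hedged sketch only; `Op` and `strip_scale` are hypothetical names and do not reflect the onnxscript pattern API. For brevity it assumes the constant is the second operand.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Op:
    """Toy expression node (hypothetical): kind is "Mul", "Div", or "MatMul"."""

    kind: str
    args: tuple


Expr = Union[Op, str, float]


def strip_scale(expr: Expr) -> Tuple[Expr, float]:
    """Match expr against (x * c) | (x / c) | x.

    Returns the unscaled operand and the effective multiplier, so callers
    handle all three alternatives with one code path.
    """
    if isinstance(expr, Op) and len(expr.args) == 2 and isinstance(expr.args[1], float):
        if expr.kind == "Mul":
            return expr.args[0], expr.args[1]
        if expr.kind == "Div":
            # Division by c is multiplication by 1/c.
            return expr.args[0], 1.0 / expr.args[1]
    # No scaling present: the multiplier defaults to 1.
    return expr, 1.0


mul_case = strip_scale(Op("Mul", ("q", 0.5)))
div_case = strip_scale(Op("Div", ("q", 4.0)))
plain_case = strip_scale("q")
```

Applying the same matcher to the query, the key, and the Q·K product, then multiplying the three returned factors, yields the single scale a fused node would need.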

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Show a summary per file
| File | Description |
|---|---|
| onnxscript/rewriter/ort_fusions/sdpa_test.py | Updated custom scale factor computation for tests. |
| onnxscript/rewriter/ort_fusions/sdpa.py | Simplified SDPA rewrite & check logic using disjunction patterns. |
| onnxscript/rewriter/ort_fusions/gqa_test.py | Added tensor value_info updates for key/value tensors. |
| onnxscript/rewriter/_rewrite_rule.py | Wrapped condition-checking logic with try/except handling. |
| onnxscript/rewriter/_fusion_utils.py | Introduced a stricter shape check function with improved error messaging. |
| onnxscript/rewriter/_basics.py | Added MatchFailure error classes for pattern matching failures. |

gramalingam and others added 2 commits May 30, 2025 20:58
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR cleans up the SDPA fusion logic by consolidating many specialized rules into a single pattern with disjunctions, unifying scale-factor handling, and adding explicit shape checks and failure reporting.

  • Use pattern.OrValue to match mul/div/none cases for query, key, and Q·K scaling instead of multiple separate rules.
  • Introduce a helper get_scale_value and simplify default‐vs‐custom scale logic.
  • Centralize shape validation via a new check_shape in _fusion_utils, and improve the handling of match failures.
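
The centralized shape check in the last bullet can be sketched as a function that binds symbolic dimension names on first use and raises on any later mismatch. This is an illustration under stated assumptions, not the `check_shape` added in `_fusion_utils`: the signature, the string stand-in for symbolic dimensions, and the error text are all hypothetical.

```python
from typing import Dict, Sequence, Union

# Assumption: plain strings stand in for symbolic dimensions in this sketch.
Dim = Union[int, str]


class MatchFailureError(Exception):
    """Hypothetical stand-in for the PR's match-failure exception."""

    def __init__(self, reason, *failure_sources):
        super().__init__(reason)
        self.reason = reason
        self.failure_sources = failure_sources


def check_shape(
    bindings: Dict[str, Dim], name: str, actual: Sequence[Dim], expected: Sequence[str]
) -> None:
    """Check actual against symbolic names, binding each name on first use."""
    if len(actual) != len(expected):
        raise MatchFailureError(
            f"{name}: expected rank {len(expected)}, got {len(actual)}", name
        )
    for axis, (dim, symbol) in enumerate(zip(actual, expected)):
        bound = bindings.setdefault(symbol, dim)
        if bound != dim:
            raise MatchFailureError(
                f"{name}: axis {axis} ({symbol}) is {dim!r}, "
                f"but {symbol} is already bound to {bound!r}",
                name,
            )


bindings: Dict[str, Dim] = {}
check_shape(bindings, "query", [2, 8, "S", 64], ["B", "H", "S", "E"])
check_shape(bindings, "key", [2, 8, "T", 64], ["B", "H", "Skv", "E"])

failure = None
try:
    # The head count (4) conflicts with the H bound from query and key (8).
    check_shape(bindings, "value", [2, 4, "T", 64], ["B", "H", "Skv", "Ev"])
except MatchFailureError as e:
    failure = e
```

Raising an exception that carries both a reason and its source keeps the calling check function linear: each shape is validated in one line, and the error message is produced where the mismatch is detected.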

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Show a summary per file
| File | Description |
|---|---|
| sdpa_test.py | Updated custom scale constants and simplified fusion tests |
| sdpa.py | Replaced many rules with a single SDPA.rule() using OrValue, refactored check/rewrite |
| gqa_test.py | Added total_seqlen and corresponding value_info entries |
| onnxscript/rewriter/_rewrite_rule.py | Wrapped condition checks in try/except to capture MatchFailureError |
| onnxscript/rewriter/_fusion_utils.py | Added check_shape utility raising MatchFailureError |
| onnxscript/rewriter/_basics.py | Introduced MatchFailureInfo and MatchFailureError classes |
Comments suppressed due to low confidence (2)

onnxscript/rewriter/_fusion_utils.py:13

  • Union is not imported in this module, leading to a NameError. Add from typing import Union at the top.
Dim = Union[int, ir.SymbolicDim]

onnxscript/rewriter/_rewrite_rule.py:183

  • MatchFailureError defines failure_sources, not failure_nodes_and_values. This line should use e.failure_sources or expose a matching property on the exception.
list(e.failure_nodes_and_values),
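
The second suppressed comment concerns which attribute the try/except wrapper reads from the exception. A minimal sketch of that wrapper pattern follows, assuming a hypothetical `MatchFailureError` that stores its sources under `failure_sources` (the attribute name the comment says the class defines). `run_condition` and `require_rank_4` are illustrative helpers, not the code in `_rewrite_rule.py`.

```python
class MatchFailureError(Exception):
    """Hypothetical stand-in; only the name failure_sources comes from the review comment."""

    def __init__(self, reason, *failure_sources):
        super().__init__(reason)
        self.reason = reason
        self.failure_sources = failure_sources


def run_condition(check_fn, *args):
    """Call a condition function and convert a raised failure into a result tuple."""
    try:
        ok = bool(check_fn(*args))
    except MatchFailureError as e:
        # Read the attribute the exception actually defines; a mismatched
        # name here would raise AttributeError inside the handler.
        return False, e.reason, list(e.failure_sources)
    return ok, "", []


def require_rank_4(name, shape):
    """Example condition: succeed only for rank-4 shapes."""
    if len(shape) != 4:
        raise MatchFailureError(f"{name}: expected rank 4, got {len(shape)}", name)
    return True


passed = run_condition(require_rank_4, "query", [2, 8, 16, 64])
failed = run_condition(require_rank_4, "key", [2, 16, 64])
```

With this shape, condition functions can either return a falsy value or raise, and the caller sees the same kind of result in both cases.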

@titaiwangms titaiwangms self-requested a review June 2, 2025 17:39
Contributor

@titaiwangms titaiwangms left a comment

LGTM

@gramalingam gramalingam enabled auto-merge (squash) June 3, 2025 19:27
@gramalingam gramalingam merged commit e49620a into main Jun 3, 2025
28 of 32 checks passed
@gramalingam gramalingam deleted the rama/sdpa branch June 3, 2025 19:28
@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Jun 3, 2025
bmehta001 pushed a commit to bmehta001/onnxscript that referenced this pull request Jun 5, 2025
Remove the need for many different rules for SDPA fusion by (a) Using
pattern-disjunction, and (b) Simplifying the handling of scaling factors
which can occur in several forms (using either multiplication or
division, either separately to query and/or key, or to the product of
query and key).

Also: simplify the way shapes are checked and error messages are
generated.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <[email protected]>