-
Notifications
You must be signed in to change notification settings - Fork 72
Fix fused matmul check/rewrite functions #2331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix fused matmul check/rewrite functions #2331
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2331 +/- ##
==========================================
- Coverage 70.23% 70.14% -0.09%
==========================================
Files 197 197
Lines 24748 24964 +216
Branches 2652 2667 +15
==========================================
+ Hits 17381 17512 +131
- Misses 6446 6524 +78
- Partials 921 928 +7 ☔ View full report in Codecov by Sentry. |
6e05537
to
71c1d5c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes the fused matmul check/rewrite functions by updating argument handling, error messaging, and test coverage while extending support for batch transpose operations. Key changes include:
- Updating checks for batch dimensions and transpose operations in the FusedMatMul operator.
- Refactoring rewrite rules to use explicit attribute extraction via fused.producer() and adding new rules for flipping transBatch and trans positions.
- Enhancing the test suite with new models to verify the behavior when fused matmul is combined with intermediate Transpose or MatMul nodes.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py | Updated tests to cover additional fused matmul cases and introduced new models for batch handling |
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | Refactored rewrite rules with clearer attribute extraction and added new rules for handling transBatch and transpose flipping |
Comments suppressed due to low confidence (1)
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py:195
- [nitpick] The variable name 'property' is ambiguous; consider renaming it (e.g., 'transBatchProperty') to improve clarity.
if self._pos == 1:
property = "transBatchA"
else:
property = "transBatchB"
Thanks - also feel free to point out things that are unintuitive or make suggestions as you use the onnx ir and rewriter apis. Would love to hear about them! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances fused MatMul rewrite rules by leveraging producer()
calls, strong-typed attribute extraction, and adds comprehensive batch-transpose patterns along with corresponding unit tests.
- Introduced
get_node
/get_kwargs
helpers and updated rewrite/check logic to use.producer()
andas_float
/as_ints
. - Extended transpose‐fusion rules to support flipping batch and matrix transpose flags across six new scenarios.
- Updated tests to pass explicit
transA
,transB
,transBatchA
,transBatchB
defaults and added models for in-middle and batch-transpose cases.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
fused_matmul_rule_sets.py | Added helper functions; updated pattern checks and rewrites to use producer-based lookup and strong-typed attributes; introduced new batch-transpose rules |
fused_matmul_rule_sets_test.py | Expanded _run helper to handle batch‐transpose; added tests for fusion in middle of graph and all batch-transpose scenarios |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes the fused matmul check and rewrite functions by updating attribute extraction for type safety, modifying rewrite patterns to use the new fused producer functions, and enhancing unit tests to cover various transposition and batching cases.
- Updated attribute extraction using as_float() and as_ints()
- Revised patterns to explicitly tag intermediate outputs with _outputs
- Added new test cases for handling transposition in models with batch dimensions
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py | Updates tests to validate batch and transpose conditions with fused matmul rewrites |
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | Adjustments to check/rewrite rules, improved helper functions, and new rewrite patterns for batch and transpose handling |
Remove the need for many different rules for SDPA fusion by (a) Using pattern-disjunction, and (b) Simplifying the handling of scaling factors which can occur in several forms (using either multiplication or division, either separately to query and/or key, or to the product of query and key). Also: simplify the way shapes are checked and error messages are generated. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Ti-Tai Wang <[email protected]>
Fix microsoft#2105 For the logic, this PR follows https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/dialect/common/cse_pass.py. Essentially, this PR traverses the original graph and examines whether the values or the nodes are duplicated. If it's not, the value or the node is saved in mappings, and added to the new graph. If it is duplicated, the value or the node is replaced with the mapped/saved value or node. (FunctionalPass) CSE subgraph is not supported: microsoft#2345. --------- Co-authored-by: Copilot <[email protected]>
Use get_ints Co-authored-by: Justin Chu <[email protected]>
Rm _get_ints Co-authored-by: Justin Chu <[email protected]>
.producer()
and the added conditions related totransBatch
All tests pass