Fix fused matmul check/rewrite functions #2331

Merged
bmehta001 merged 39 commits into microsoft:main from bhamehta/fusedmatmul_find_ops on Jun 6, 2025

bmehta001 (Contributor) commented May 22, 2025

  • Patterns now declare _outputs filters to bind intermediate values
  • Rewrites use fused.producer() or transposed.producer() instead of scanning .uses(), which may pick up other nodes that use x or y
  • ir.Value parameters default to None in case the parameter does not exist
  • Attribute extraction now uses as_float() / as_ints() for type safety
  • Rewrite/check functions are passed every ir.Value variable but may not use all of them, so they accept **_ to absorb the unused ones
  • Updated docstrings from "by" to "with" for clarity and changed fusedmatmul to matmul where appropriate
  • Added more patterns:
    1. If Transpose.perm is [1:-1, 0, -1] and transBatchA is 0, we can change transBatchA to 1
    2. If Transpose.perm is [-2, 0:-2, -1] and transBatchA is 1, we can change transBatchA to 0
    3. If Transpose.perm is [1:, 0] and transBatchA is 0, we can change transBatchA to 1 and transA to 1 - transA
    4. If Transpose.perm is [-1, 0:-1] and transBatchA is 1, we can change transBatchA to 0 and transA to 1 - transA
    5. If Transpose.perm is [-1, 1:-1, 0] and transBatchA is 1, we can change transA to 1 - transA
    6. Rules 1-5 apply in the same way to transBatchB
  • Added tests covering the .producer() changes and the new transBatch conditions. All tests pass.
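
The transBatch conditions listed above can be read as a small decision table over the Transpose permutation. The following is only a sketch of that table in plain Python, not the code from this PR: the function name flip_flags is hypothetical, the slice notation such as [1:-1, 0, -1] is expanded for a rank-n input, and the requirement of rank 3 or higher is an assumption made here so that at least one batch dimension exists.

```python
from typing import Optional, Sequence, Tuple


def flip_flags(
    perm: Sequence[int], trans: int, trans_batch: int
) -> Optional[Tuple[int, int]]:
    """Return the updated (trans, trans_batch) pair, or None if no rule applies.

    Hypothetical helper: it mirrors rules 1-5 from the description for one
    input (A or B) and is not taken from fused_matmul_rule_sets.py.
    """
    n = len(perm)
    if n < 3:  # assumption: rules need at least one batch dimension
        return None
    p = [i % n for i in perm]      # normalize negative indices
    batch = list(range(1, n - 1))  # the slice 1:-1
    if trans_batch == 0:
        if p == batch + [0, n - 1]:        # rule 1: [1:-1, 0, -1]
            return trans, 1
        if p == list(range(1, n)) + [0]:   # rule 3: [1:, 0]
            return 1 - trans, 1
    else:
        if p == [n - 2] + list(range(n - 2)) + [n - 1]:  # rule 2: [-2, 0:-2, -1]
            return trans, 0
        if p == [n - 1] + list(range(n - 1)):            # rule 4: [-1, 0:-1]
            return 1 - trans, 0
        if p == [n - 1] + batch + [0]:                   # rule 5: [-1, 1:-1, 0]
            return 1 - trans, 1
    return None
```

For a rank-4 input, flip_flags([1, 2, 0, 3], trans=0, trans_batch=0) returns (0, 1), which corresponds to rule 1; an identity permutation matches no rule and returns None.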

codecov bot commented May 22, 2025

Codecov Report

Attention: Patch coverage is 70.72368% with 89 lines in your changes missing coverage. Please review.

Project coverage is 70.14%. Comparing base (af452c7) to head (514649f).
Report is 1 commits behind head on main.

Files with missing lines | Patch % | Lines
...ewriter/ort_fusions/fused_matmul_rule_sets_test.py | 59.67% | 74 Missing and 1 partial ⚠️
...ipt/rewriter/ort_fusions/fused_matmul_rule_sets.py | 88.13% | 6 Missing and 8 partials ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2331      +/-   ##
==========================================
- Coverage   70.23%   70.14%   -0.09%     
==========================================
  Files         197      197              
  Lines       24748    24964     +216     
  Branches     2652     2667      +15     
==========================================
+ Hits        17381    17512     +131     
- Misses       6446     6524      +78     
- Partials      921      928       +7     


@bmehta001 bmehta001 force-pushed the bhamehta/fusedmatmul_find_ops branch from 6e05537 to 71c1d5c Compare May 22, 2025 19:02
@bmehta001 bmehta001 marked this pull request as ready for review May 28, 2025 21:23
@bmehta001 bmehta001 requested a review from gramalingam May 28, 2025 21:23
@justinchuby justinchuby requested a review from Copilot May 28, 2025 21:28
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR fixes the fused matmul check/rewrite functions by updating argument handling, error messaging, and test coverage while extending support for batch transpose operations. Key changes include:

  • Updating checks for batch dimensions and transpose operations in the FusedMatMul operator.
  • Refactoring rewrite rules to use explicit attribute extraction via fused.producer() and adding new rules for flipping transBatch and trans positions.
  • Enhancing the test suite with new models to verify the behavior when fused matmul is combined with intermediate Transpose or MatMul nodes.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File | Description
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py | Updated tests to cover additional fused matmul cases and introduced new models for batch handling
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | Refactored rewrite rules with clearer attribute extraction and added new rules for handling transBatch and transpose flipping
Comments suppressed due to low confidence (1)

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py:195

  • [nitpick] The variable name 'property' is ambiguous; consider renaming it (e.g., 'transBatchProperty') to improve clarity.
if self._pos == 1:
    property = "transBatchA"
else:
    property = "transBatchB"
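
One way to act on this nitpick is sketched below. Note that property is also the name of a Python builtin, so the assignment in the snippet shadows it inside that scope. The helper name trans_batch_attr_name is hypothetical and not from the PR; the mapping of position 1 to transBatchA follows the quoted snippet.

```python
def trans_batch_attr_name(pos: int) -> str:
    """Return the batch-transpose attribute name for an input position.

    Hypothetical helper: position 1 maps to "transBatchA" and any other
    position to "transBatchB", as in the snippet quoted in the review.
    """
    return "transBatchA" if pos == 1 else "transBatchB"
```

Returning the name from a small function avoids both the ambiguous variable and the shadowed builtin.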

justinchuby (Collaborator) commented
Thanks - also feel free to point out things that are unintuitive or make suggestions as you use the onnx ir and rewriter apis. Would love to hear about them!

@bmehta001 bmehta001 requested a review from Copilot May 29, 2025 15:26
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR enhances the fused MatMul rewrite rules by using producer() calls and strongly typed attribute extraction, and adds comprehensive batch-transpose patterns along with corresponding unit tests.

  • Introduced get_node/get_kwargs helpers and updated rewrite/check logic to use .producer() and as_float/as_ints.
  • Extended transpose-fusion rules to support flipping batch and matrix transpose flags across six new scenarios.
  • Updated tests to pass explicit transA, transB, transBatchA, transBatchB defaults and added models for in-middle and batch-transpose cases.
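
The calling convention described in the PR description (every bound value is passed by keyword, optional values default to None, and unused ones are absorbed by **_) can be illustrated without the rewriter itself. This is a minimal sketch in plain Python: the check function, the bindings dictionary, and the string placeholders are all hypothetical stand-ins for ir.Value objects and for whatever the pattern matcher actually passes.

```python
from typing import Any, Optional


def check(x: Any, y: Any, bias: Optional[Any] = None, **_: Any) -> bool:
    """Hypothetical check function.

    Uses only x and y, tolerates a missing bias through its None default,
    and silently accepts any other bound names through **_.
    """
    return x is not None and y is not None


# Stand-in for the bindings a matcher might pass: every named value, by keyword.
bindings = {"x": "X", "y": "Y", "transposed": "T", "fused": "F"}
result = check(**bindings)  # "transposed" and "fused" land in **_
```

Without **_, the call above would raise a TypeError for the unexpected keyword arguments, which is why the convention matters when one pattern feeds several check and rewrite functions.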

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File | Description
fused_matmul_rule_sets.py | Added helper functions; updated pattern checks and rewrites to use producer-based lookup and strongly typed attributes; introduced new batch-transpose rules
fused_matmul_rule_sets_test.py | Expanded _run helper to handle batch-transpose; added tests for fusion in middle of graph and all batch-transpose scenarios

@bmehta001 bmehta001 requested a review from Copilot May 29, 2025 16:32
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR fixes the fused matmul check and rewrite functions by updating attribute extraction for type safety, modifying rewrite patterns to use the new fused producer functions, and enhancing unit tests to cover various transposition and batching cases.

  • Updated attribute extraction using as_float() and as_ints()
  • Revised patterns to explicitly tag intermediate outputs with _outputs
  • Added new test cases for handling transposition in models with batch dimensions

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py | Updates tests to validate batch and transpose conditions with fused matmul rewrites
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py | Adjustments to check/rewrite rules, improved helper functions, and new rewrite patterns for batch and transpose handling

gramalingam and others added 9 commits June 5, 2025 17:35
Remove the need for many different rules for SDPA fusion by (a) Using
pattern-disjunction, and (b) Simplifying the handling of scaling factors
which can occur in several forms (using either multiplication or
division, either separately to query and/or key, or to the product of
query and key).

Also: simplify the way shapes are checked and error messages are
generated.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <[email protected]>
Fix microsoft#2105 

For the logic, this PR follows
https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/dialect/common/cse_pass.py.

Essentially, this PR traverses the original graph and checks whether
each value or node is a duplicate. If it is not, the value or node is
saved in a mapping and added to the new graph. If it is a duplicate,
it is replaced with the previously saved value or node.
(FunctionalPass)
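
The traversal described in this commit message can be sketched on a toy graph. This is an illustration of common subexpression elimination in general, not the pass added by that commit: the Node tuple and the function name are hypothetical, nodes are assumed to be in topological order, and attributes, graph outputs, and subgraphs are ignored.

```python
from typing import Dict, List, NamedTuple, Tuple


class Node(NamedTuple):
    op: str
    inputs: Tuple[str, ...]
    output: str


def eliminate_common_subexpressions(
    nodes: List[Node],
) -> Tuple[List[Node], Dict[str, str]]:
    """Drop nodes that recompute an earlier (op, inputs) pair."""
    seen: Dict[Tuple[str, Tuple[str, ...]], str] = {}  # key -> first output
    renamed: Dict[str, str] = {}  # duplicate output -> first output
    kept: List[Node] = []
    for node in nodes:
        # Rewire inputs that pointed at an already-removed duplicate.
        inputs = tuple(renamed.get(name, name) for name in node.inputs)
        key = (node.op, inputs)
        if key in seen:
            renamed[node.output] = seen[key]  # duplicate: reuse saved value
        else:
            seen[key] = node.output           # first occurrence: save and keep
            kept.append(Node(node.op, inputs, node.output))
    return kept, renamed


graph = [
    Node("Add", ("a", "b"), "t1"),
    Node("Add", ("a", "b"), "t2"),  # duplicate of the node producing t1
    Node("Mul", ("t1", "t2"), "out"),
]
kept, renamed = eliminate_common_subexpressions(graph)
```

Here the second Add is dropped and the Mul node is rewired to take t1 for both inputs.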

CSE subgraph is not supported:
microsoft#2345.

---------

Co-authored-by: Copilot <[email protected]>
@bmehta001 bmehta001 requested a review from justinchuby June 5, 2025 22:03
@bmehta001 bmehta001 enabled auto-merge (squash) June 6, 2025 02:22
@bmehta001 bmehta001 merged commit 5293005 into microsoft:main Jun 6, 2025
26 of 32 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Jun 6, 2025
@bmehta001 bmehta001 deleted the bhamehta/fusedmatmul_find_ops branch June 6, 2025 03:10
4 participants