-
Notifications
You must be signed in to change notification settings - Fork 72
[rewrite] Specify transpose op when TransposeMatMul checks the pattern #2317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rewrite] Specify transpose op when TransposeMatMul checks the pattern #2317
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2317 +/- ##
=======================================
Coverage 73.76% 73.77%
=======================================
Files 239 239
Lines 30904 30917 +13
Branches 3494 3496 +2
=======================================
+ Hits 22797 22809 +12
Misses 6907 6907
- Partials 1200 1201 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors several rewrite and check methods to explicitly select Transpose nodes via consumers()
instead of using a blanket uses()
call, and adds logic to handle multiple consumers and specify the correct transpose permutation.
- Switch
uses()
toconsumers()
for node retrieval in rewrite paths - Introduce filtering for
Transpose
nodes incheck
and assert presence - Simplify node lookup logic across multiple methods
Comments suppressed due to low confidence (1)
onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py:79
- [nitpick] Variable name 'the_value' is ambiguous; consider renaming to something more descriptive like 'input_tensor' or 'target_tensor'.
the_value = x if self._pos == 1 else y
@@ -45,7 +45,7 @@ def check(self, context, x, y, cst) -> orp.MatchResult: | |||
def rewrite(self, op, x, y, cst): | |||
value = cst.const_value.numpy() | |||
c = float(value[0] if value.shape == (1,) else value) | |||
node = list(x.uses())[0][0] # noqa: RUF015 | |||
node = x.consumers()[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing consumers()[0] without verifying the list is non-empty can raise an IndexError if there are no consumers. Consider checking the list length and handling the empty case gracefully or returning a MatchResult.fail.
node = x.consumers()[0] | |
consumers = x.consumers() | |
if not consumers: | |
return orp.MatchResult.fail("No consumers found for the input node.") | |
node = consumers[0] |
Copilot uses AI. Check for mistakes.
perm = None | ||
for node in nodes: | ||
if node.op_type == "Transpose": | ||
perm = node.attributes["perm"].value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Loop continues iterating even after finding the Transpose node. Consider breaking out of the loop on first match to avoid unnecessary iterations.
perm = node.attributes["perm"].value | |
perm = node.attributes["perm"].value | |
break |
Copilot uses AI. Check for mistakes.
@@ -119,8 +126,8 @@ def pattern(self, op, x, y): | |||
|
|||
def check(self, context, x, y) -> orp.MatchResult: | |||
check_result = orp.MatchResult() | |||
matmul = list(x.uses())[0][0] # noqa: RUF015 | |||
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 | |||
matmul = x.consumers()[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Directly indexing consumers()[0] may fail if there are no consumers. Consider checking for emptiness and returning a fail result if none exist.
matmul = x.consumers()[0] | |
consumers = x.consumers() | |
if not consumers: | |
return check_result.fail("No consumers found for the input node.") | |
matmul = consumers[0] |
Copilot uses AI. Check for mistakes.
Hi @titaiwangms : the rules are somewhat broken. I think a more significant rewrite would be better. Can we remove the matmul fusion rules from the default rule set, so that we get something working first? The rules should be rewritten so that they don't assume anything about how many users() there are. That's much safer and can be done easily. This may be a good exercise for Bhagirath, so I can ask him to look into it. What do you think? |
How do you tell it's broken? FusedMatMul (Transpose+MatMul) appears in Whisper, but I think it has no difference because we also have lifting transpose pass. |
Replaced by #2331 |
Found in Whisper that
list((x if self._pos == 1 else y).uses())[0][0]
could be not op.Transpose. This PR specifies transpose node so that it does not crash with keyerror on 'perm'.