-
Notifications
You must be signed in to change notification settings - Fork 72
[rewrite] Specify transpose op when TransposeMatMul checks the pattern #2317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -45,7 +45,8 @@ def check(self, context, x, y, cst) -> orp.MatchResult: | |||||||||||
def rewrite(self, op, x, y, cst): | ||||||||||||
value = cst.const_value.numpy() | ||||||||||||
c = float(value[0] if value.shape == (1,) else value) | ||||||||||||
node = list(x.uses())[0][0] # noqa: RUF015 | ||||||||||||
assert x.consumers() | ||||||||||||
node = x.consumers()[0] | ||||||||||||
|
||||||||||||
kwargs = {} | ||||||||||||
alpha = node.attributes.get("alpha", None) | ||||||||||||
|
@@ -62,15 +63,23 @@ class _TransposeMatMulBase(orp.RewriteRuleClassBase): | |||||||||||
|
||||||||||||
def check(self, context, x, y) -> orp.MatchResult: | ||||||||||||
check_result = orp.MatchResult() | ||||||||||||
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015 | ||||||||||||
# The value: x, y could be consumed by multiple nodes. | ||||||||||||
nodes = (x if self._pos == 1 else y).consumers() | ||||||||||||
perm = None | ||||||||||||
for node in nodes: | ||||||||||||
if node.op_type == "Transpose": | ||||||||||||
perm = node.attributes["perm"].value | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Loop continues iterating even after finding the Transpose node. Consider breaking out of the loop on first match to avoid unnecessary iterations.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||
assert perm is not None, "Transpose node not found." | ||||||||||||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
expected_perm = list(range(len(perm))) | ||||||||||||
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] | ||||||||||||
if perm != expected_perm: | ||||||||||||
return check_result.fail("Permutation values for Transpose are not correct.") | ||||||||||||
return check_result | ||||||||||||
|
||||||||||||
def rewrite(self, op, x, y): | ||||||||||||
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015 | ||||||||||||
the_value = x if self._pos == 2 else y | ||||||||||||
assert the_value.consumers() | ||||||||||||
node = the_value.consumers()[0] | ||||||||||||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
kwargs = {} | ||||||||||||
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: | ||||||||||||
att = node.attributes.get(name) | ||||||||||||
|
@@ -119,8 +128,10 @@ def pattern(self, op, x, y): | |||||||||||
|
||||||||||||
def check(self, context, x, y) -> orp.MatchResult: | ||||||||||||
check_result = orp.MatchResult() | ||||||||||||
matmul = list(x.uses())[0][0] # noqa: RUF015 | ||||||||||||
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 | ||||||||||||
assert x.consumers() | ||||||||||||
matmul = x.consumers()[0] | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Directly indexing consumers()[0] may fail if there are no consumers. Consider checking for emptiness and returning a fail result if none exist.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||
assert matmul.outputs[0].consumers() | ||||||||||||
transpose = matmul.outputs[0].consumers()[0] | ||||||||||||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
perm = transpose.attributes["perm"].value | ||||||||||||
expected_perm = list(range(len(perm))) | ||||||||||||
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] | ||||||||||||
|
@@ -129,7 +140,8 @@ def check(self, context, x, y) -> orp.MatchResult: | |||||||||||
return check_result | ||||||||||||
|
||||||||||||
def rewrite(self, op, x, y): | ||||||||||||
node = list(x.uses())[0][0] # noqa: RUF015 | ||||||||||||
assert x.consumers() | ||||||||||||
node = x.consumers()[0] | ||||||||||||
kwargs = {} | ||||||||||||
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: | ||||||||||||
att = node.attributes.get(name) | ||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accessing consumers()[0] without verifying the list is non-empty can raise an IndexError if there are no consumers. Consider checking the list length and handling the empty case gracefully or returning a MatchResult.fail.
Copilot uses AI. Check for mistakes.