Skip to content

[rewrite] Specify transpose op when TransposeMatMul checks the pattern #2317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def check(self, context, x, y, cst) -> orp.MatchResult:
def rewrite(self, op, x, y, cst):
value = cst.const_value.numpy()
c = float(value[0] if value.shape == (1,) else value)
node = list(x.uses())[0][0] # noqa: RUF015
assert x.consumers()
node = x.consumers()[0]
Copy link
Preview

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accessing consumers()[0] without verifying the list is non-empty can raise an IndexError if there are no consumers. Consider checking the list length and handling the empty case gracefully or returning a MatchResult.fail.

Suggested change
node = x.consumers()[0]
consumers = x.consumers()
if not consumers:
return orp.MatchResult.fail("No consumers found for the input node.")
node = consumers[0]

Copilot uses AI. Check for mistakes.


kwargs = {}
alpha = node.attributes.get("alpha", None)
Expand All @@ -62,15 +63,23 @@ class _TransposeMatMulBase(orp.RewriteRuleClassBase):

def check(self, context, x, y) -> orp.MatchResult:
check_result = orp.MatchResult()
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
# The value: x, y could be consumed by multiple nodes.
nodes = (x if self._pos == 1 else y).consumers()
perm = None
for node in nodes:
if node.op_type == "Transpose":
perm = node.attributes["perm"].value
Copy link
Preview

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Loop continues iterating even after finding the Transpose node. Consider breaking out of the loop on first match to avoid unnecessary iterations.

Suggested change
perm = node.attributes["perm"].value
perm = node.attributes["perm"].value
break

Copilot uses AI. Check for mistakes.

assert perm is not None, "Transpose node not found."
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
return check_result

def rewrite(self, op, x, y):
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015
the_value = x if self._pos == 2 else y
assert the_value.consumers()
node = the_value.consumers()[0]
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
Expand Down Expand Up @@ -119,8 +128,10 @@ def pattern(self, op, x, y):

def check(self, context, x, y) -> orp.MatchResult:
check_result = orp.MatchResult()
matmul = list(x.uses())[0][0] # noqa: RUF015
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
assert x.consumers()
matmul = x.consumers()[0]
Copy link
Preview

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Directly indexing consumers()[0] may fail if there are no consumers. Consider checking for emptiness and returning a fail result if none exist.

Suggested change
matmul = x.consumers()[0]
consumers = x.consumers()
if not consumers:
return check_result.fail("No consumers found for the input node.")
matmul = consumers[0]

Copilot uses AI. Check for mistakes.

assert matmul.outputs[0].consumers()
transpose = matmul.outputs[0].consumers()[0]
perm = transpose.attributes["perm"].value
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
Expand All @@ -129,7 +140,8 @@ def check(self, context, x, y) -> orp.MatchResult:
return check_result

def rewrite(self, op, x, y):
node = list(x.uses())[0][0] # noqa: RUF015
assert x.consumers()
node = x.consumers()[0]
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
Expand Down
Loading