diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index cc10297afe..f4d62880cf 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -45,7 +45,8 @@ def check(self, context, x, y, cst) -> orp.MatchResult: def rewrite(self, op, x, y, cst): value = cst.const_value.numpy() c = float(value[0] if value.shape == (1,) else value) - node = list(x.uses())[0][0] # noqa: RUF015 + assert x.consumers() + node = x.consumers()[0] kwargs = {} alpha = node.attributes.get("alpha", None) @@ -62,7 +63,13 @@ class _TransposeMatMulBase(orp.RewriteRuleClassBase): def check(self, context, x, y) -> orp.MatchResult: check_result = orp.MatchResult() - perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015 + # The value: x, y could be consumed by multiple nodes. + nodes = (x if self._pos == 1 else y).consumers() + perm = None + for node in nodes: + if node.op_type == "Transpose": + perm = node.attributes["perm"].value + assert perm is not None, "Transpose node not found." expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] if perm != expected_perm: @@ -70,7 +77,9 @@ def check(self, context, x, y) -> orp.MatchResult: return check_result def rewrite(self, op, x, y): - node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015 + the_value = x if self._pos == 2 else y + assert the_value.consumers() + node = the_value.consumers()[0] kwargs = {} for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: att = node.attributes.get(name) @@ -119,8 +128,10 @@ def pattern(self, op, x, y): def check(self, context, x, y) -> orp.MatchResult: check_result = orp.MatchResult() - matmul = list(x.uses())[0][0] # noqa: RUF015 - transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 + assert x.consumers() + matmul = x.consumers()[0] + assert matmul.outputs[0].consumers() + transpose = matmul.outputs[0].consumers()[0] perm = transpose.attributes["perm"].value expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] @@ -129,7 +140,8 @@ def check(self, context, x, y) -> orp.MatchResult: return check_result def rewrite(self, op, x, y): - node = list(x.uses())[0][0] # noqa: RUF015 + assert x.consumers() + node = x.consumers()[0] kwargs = {} for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: att = node.attributes.get(name)