Skip to content

Fix fused matmul check/rewrite functions #2331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
71c1d5c
Use producer syntax, simplify, rm unnecessary checks
bmehta001 May 22, 2025
77b39f6
Simplify assert, assigning attributes
bmehta001 May 27, 2025
af0abbd
Add test to ensure fusion rules do not rely on position of node
bmehta001 May 27, 2025
65f4637
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 28, 2025
8946821
Add checking for transBatch
bmehta001 May 28, 2025
0f2b287
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 28, 2025
cb5a192
Fix condition, formatting, and add default
bmehta001 May 28, 2025
b2e9737
Fix formatting
bmehta001 May 28, 2025
af9c064
Simplify syntax w/ functions
bmehta001 May 29, 2025
4611fdc
Condense rules using type function and classVars
bmehta001 May 29, 2025
bee83a8
Rm unused/fix comment
bmehta001 May 29, 2025
9f81fe9
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 29, 2025
b9abd98
Address comments
bmehta001 May 29, 2025
10162db
Fix None error + rm type-ignore
bmehta001 May 30, 2025
d06f388
Add more tests
bmehta001 May 30, 2025
a065541
Handle defaults and add docstring
bmehta001 May 30, 2025
1a90317
Add clarifying comment
bmehta001 May 30, 2025
d417c02
Fix correct default behavior for transpose
bmehta001 Jun 3, 2025
a23ee07
Formally drop python 3.8 support (#2354)
justinchuby May 29, 2025
d1eb856
Implement `__repr__` for MatchResult (#2353)
justinchuby May 30, 2025
19b7f6a
Use onnx_ir as a dependency (#2324)
justinchuby May 30, 2025
3fd79be
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms May 30, 2025
11075ee
Fix pytest for TestCosSinCacheTransform (#2358)
justinchuby Jun 2, 2025
9b81926
SDPA fusion cleanup (#2352)
gramalingam Jun 3, 2025
7553ce1
Require onnx-ir 0.1.1 (#2360)
justinchuby Jun 3, 2025
73432e5
Enable CSE in optimizer (#2361)
titaiwangms Jun 4, 2025
ccce52e
Rewrite tests and address comments
bmehta001 Jun 5, 2025
3654fa8
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 Jun 5, 2025
12b4cc0
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms May 30, 2025
2276a16
Enable CSE in optimizer (#2361)
titaiwangms Jun 4, 2025
2a0a798
Revert changes
bmehta001 Jun 5, 2025
5bcf2b1
Fix errors/simplify
bmehta001 Jun 5, 2025
a94c295
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 Jun 5, 2025
e976fb1
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 Jun 5, 2025
2adf8ea
Iterate through IR Model instead of ModelProto
bmehta001 Jun 5, 2025
73889d3
Addressed comments
bmehta001 Jun 5, 2025
841c49b
Simplify
bmehta001 Jun 5, 2025
04e6955
Simplify use of constants
bmehta001 Jun 5, 2025
514649f
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 Jun 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 241 additions & 44 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from typing import ClassVar

import onnxscript.rewriter.pattern as orp
from onnxscript import ir


class FusedMatMulDiv1(orp.RewriteRuleClassBase):
"""Replaces ``MatMul + Div`` by FusedMatMul."""
"""Replaces ``MatMul + Div`` with FusedMatMul."""

def pattern(self, op, x, y, cst):
return op.Div(op.MatMul(x, y), cst)
Expand All @@ -29,122 +30,312 @@


class FusedMatMulDiv2(orp.RewriteRuleClassBase):
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""
"""Replaces ``FusedMatMul + Div`` with FusedMatMul."""

def pattern(self, op, x, y, cst):
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst)

def check(self, context, x, y, cst) -> orp.MatchResult:
def check(self, context, x, y, cst, fused: ir.Value) -> orp.MatchResult:
check_result = orp.MatchResult()
if cst.const_value is None:
return check_result.fail("Divisor is not a constant value.")
if cst.const_value.numpy().size > 1:
return check_result.fail("Divisor is not a scalar value.")
return check_result

def rewrite(self, op, x, y, cst):
def rewrite(self, op, x, y, cst, fused: ir.Value):
value = cst.const_value.numpy()
c = float(value[0] if value.shape == (1,) else value)
node = list(x.uses())[0][0] # noqa: RUF015

kwargs = {}
alpha = node.attributes.get("alpha", None)
kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
for name in ["transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
node = fused.producer()
assert node is not None, "FusedMatMul node should not be None"
kwargs = {key: val.value for key, val in node.attributes.items()}
kwargs["alpha"] = node.attributes["alpha"].as_float() / c
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class _TransposeMatMulBase(orp.RewriteRuleClassBase):
_pos: ClassVar = 1

def check(self, context, x, y) -> orp.MatchResult:
def check(
self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_
) -> orp.MatchResult:
check_result = orp.MatchResult()
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
node = transposed.producer()
assert node is not None, "Transpose node should not be None"
perm = node.attributes["perm"].as_ints()
# Check that last two dimensions are swapped
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
if fused:
fused_node = fused.producer()
assert fused_node is not None, "FusedMatMul node should not be None"
if fused_node.attributes["transBatchA"].as_int() == 1 and self._pos == 2:
return check_result.fail(

Check warning on line 75 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L75

Added line #L75 was not covered by tests
"FusedMatMul with transBatchA cannot be used with Transpose(A)."
)
if fused_node.attributes["transBatchB"].as_int() == 1 and self._pos == 1:
return check_result.fail(

Check warning on line 79 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L79

Added line #L79 was not covered by tests
"FusedMatMul with transBatchB cannot be used with Transpose(B)."
)
return check_result

def rewrite(self, op, x, y):
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
if fused:
node = fused.producer()
assert node is not None, "FusedMatMul node should not be None"
kwargs = {key: val.value for key, val in node.attributes.items()}
name = "transA" if self._pos == 1 else "transB"
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class TransposeMatMul1(_TransposeMatMulBase):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.MatMul(op.Transpose(x), y)
return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y)


class TransposeFusedMatMul1(TransposeMatMul1):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")
return op.FusedMatMul(
op.Transpose(x, _outputs=["transposed"]),
y,
_domain="com.microsoft",
_outputs=["fused"],
)


class TransposeMatMul2(_TransposeMatMulBase):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""

_pos: ClassVar = 2

def pattern(self, op, x, y):
return op.MatMul(x, op.Transpose(y))
return op.MatMul(x, op.Transpose(y, _outputs=["transposed"]))


class TransposeFusedMatMul2(TransposeMatMul2):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")
return op.FusedMatMul(
x,
op.Transpose(y, _outputs=["transposed"]),
_domain="com.microsoft",
_outputs=["fused"],
)


class _TransposeFusedMatMulBaseWithBatch(orp.RewriteRuleClassBase):
"""Base class for Transpose + FusedMatMul with batch transpose support."""

_pos: ClassVar = 1
_flip_transpose_batch: ClassVar = False
_flip_transpose: ClassVar = False

def rewrite(self, op, x, y, fused: ir.Value, **_):
kwargs = {}
node = fused.producer()
assert node is not None, "FusedMatMul node should not be None"
kwargs = {key: val.value for key, val in node.attributes.items()}
name = "A" if self._pos == 1 else "B"
if self._flip_transpose_batch:
transBatchName = f"transBatch{name}"
kwargs[transBatchName] = 1 - kwargs[transBatchName]
if self._flip_transpose:
transName = f"trans{name}"
kwargs[transName] = 1 - kwargs[transName]
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")

def pattern(self, op, x, y):
if self._pos == 1:
return op.FusedMatMul(
op.Transpose(x, _outputs=["transposed"]),
y,
_domain="com.microsoft",
_outputs=["fused"],
)
else:
return op.FusedMatMul(
x,
op.Transpose(y, _outputs=["transposed"]),
_domain="com.microsoft",
_outputs=["fused"],
)


class TransposeFusedMatMulWithFlippedBatch1(_TransposeFusedMatMulBaseWithBatch):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when only transBatchA can be flipped i.e.,
when the transpose indices are [1:-1, 0, -1] for transBatchA = 0 and
[-2, 0:-2, -1] for transBatchA = 1.
"""

_flip_transpose_batch: ClassVar = True

def check(
self, context, x, y, transposed: ir.Value, fused: ir.Value, **_
) -> orp.MatchResult:
check_result = orp.MatchResult()
node = transposed.producer()
assert node is not None, "Transpose node should not be None"
fused_node = fused.producer()
assert fused_node is not None, "FusedMatMul node should not be None"
perm = node.attributes["perm"].as_ints()
# Check that last two dimensions are swapped
list_perm = list(range(len(perm)))
expected_perm0 = list_perm[1:-1] + [list_perm[0], list_perm[-1]]
expected_perm1 = [list_perm[-2]] + list_perm[0:-2] + [list_perm[-1]]
if self._pos == 1:
property = "transBatchA"
else:
property = "transBatchB"

Check warning on line 198 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L198

Added line #L198 was not covered by tests
transBatch = fused_node.attributes[property].as_int()
if (expected_perm0 == perm and transBatch == 0) or (
expected_perm1 == perm and transBatch == 1
):
return check_result
return check_result.fail("Permutation values for Transpose are not correct.")


class TransposeFusedMatMulWithFlippedBatch2(_TransposeFusedMatMulBaseWithBatch):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when only transBatchB can be flipped i.e.,
when the transpose indices are [1:-1, 0, -1] for transBatchB = 0 and
[-2, 0:-2, -1] for transBatchB = 1.
"""

_pos: ClassVar = 2


class TransposeFusedMatMulWithFlippedBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when transBatchA and transA can be flipped i.e.,
when the transpose indices are [1:-1, -1, 0] for transBatchA = 0 and
[-2, 0:-2, -1] for transBatchA = 1.
"""

_flip_transpose_batch: ClassVar = True
_flip_transpose: ClassVar = True

def check(
self, context, x, y, transposed: ir.Value, fused: ir.Value, **_
) -> orp.MatchResult:
check_result = orp.MatchResult()
node = transposed.producer()
assert node is not None, "Transpose node should not be None"
fused_node = fused.producer()
assert fused_node is not None, "FusedMatMul node should not be None"
perm = node.attributes["perm"].as_ints()
# Check that last two dimensions are swapped
list_perm = list(range(len(perm)))
expected_perm0 = list_perm[1:] + [list_perm[0]]
expected_perm1 = [list_perm[-1]] + list_perm[0:-1]
if self._pos == 1:
property = "transBatchA"
else:
property = "transBatchB"

Check warning on line 243 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L243

Added line #L243 was not covered by tests
transBatch = fused_node.attributes[property].as_int()
if (expected_perm0 == perm and transBatch == 0) or (
expected_perm1 == perm and transBatch == 1
):
return check_result
return check_result.fail("Permutation values for Transpose are not correct.")


class TransposeFusedMatMulWithFlippedBatchAndTranspose2(
TransposeFusedMatMulWithFlippedBatchAndTranspose1
):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when transBatchB and transB can be flipped i.e.,
when the transpose indices are [1:-1, -1, 0] for transBatchB = 0 and
[-2, 0:-2, -1] for transBatchB = 1.
"""

_pos: ClassVar = 2


class TransposeFusedMatMulWithBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when transBatchA = 1 and transA can be flipped i.e.,
when the transpose indices are [-1, 1:-1, 0].
"""

_flip_transpose: ClassVar = True

def check(
self, context, x, y, transposed: ir.Value, fused: ir.Value, **_
) -> orp.MatchResult:
check_result = orp.MatchResult()
node = transposed.producer()
assert node is not None, "Transpose node should not be None"
fused_node = fused.producer()
assert fused_node is not None, "FusedMatMul node should not be None"
perm = node.attributes["perm"].as_ints()
# Check that last two dimensions are swapped
list_perm = list(range(len(perm)))
expected_perm = [list_perm[-1]] + list_perm[1:-1] + [list_perm[0]]
if self._pos == 1:
property = "transBatchA"
else:
property = "transBatchB"

Check warning on line 287 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L287

Added line #L287 was not covered by tests
transBatch = fused_node.attributes[property].as_int()
if expected_perm == perm and transBatch == 1:
return check_result
return check_result.fail("Permutation values for Transpose are not correct.")

Check warning on line 291 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L291

Added line #L291 was not covered by tests


class TransposeFusedMatMulWithBatchAndTranspose2(TransposeFusedMatMulWithBatchAndTranspose1):
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul.
This rule is for when transBatchB = 1 and transB can be flipped i.e.,
when the transpose indices are [-1, 1:-1, 0].
"""

_pos: ClassVar = 2


class MatMulTranspose(orp.RewriteRuleClassBase):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
"""Replaces ``MatMul + Transpose`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.Transpose(op.MatMul(x, y))
return op.Transpose(op.MatMul(x, y), _outputs=["transposed"])

def check(self, context, x, y) -> orp.MatchResult:
def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
check_result = orp.MatchResult()
matmul = list(x.uses())[0][0] # noqa: RUF015
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
perm = transpose.attributes["perm"].value
transpose = transposed.producer()
assert transpose is not None, "Transpose node should not be None"
perm = transpose.attributes["perm"].as_ints()
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
return check_result

def rewrite(self, op, x, y):
node = list(x.uses())[0][0] # noqa: RUF015
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
if fused:
node = fused.producer()
assert node is not None, "FusedMatMul node should not be None"
kwargs = {key: val.value for key, val in node.attributes.items()}
for name in ["transA", "transB"]:
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")


class FusedMatMulTranspose(MatMulTranspose):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
"""Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))
return op.Transpose(
op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]),
_outputs=["transposed"],
)


def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
Expand All @@ -165,5 +356,11 @@
TransposeFusedMatMul1.rule(),
TransposeMatMul2.rule(),
TransposeFusedMatMul2.rule(),
TransposeFusedMatMulWithFlippedBatch1.rule(),
TransposeFusedMatMulWithFlippedBatch2.rule(),
TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(),
TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(),
TransposeFusedMatMulWithBatchAndTranspose1.rule(),
TransposeFusedMatMulWithBatchAndTranspose2.rule(),
]
)
Loading
Loading