Skip to content

Fix fused matmul check/rewrite functions #2331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
71c1d5c
Use producer syntax, simplify, rm unnecessary checks
bmehta001 May 22, 2025
77b39f6
Simplify assert, assigning attributes
bmehta001 May 27, 2025
af0abbd
Add test to ensure fusion rules do not rely on position of node
bmehta001 May 27, 2025
65f4637
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 28, 2025
8946821
Add checking for transBatch
bmehta001 May 28, 2025
0f2b287
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 28, 2025
cb5a192
Fix condition, formatting, and add default
bmehta001 May 28, 2025
b2e9737
Fix formatting
bmehta001 May 28, 2025
af9c064
Simplify syntax w/ functions
bmehta001 May 29, 2025
4611fdc
Condense rules using type function and classVars
bmehta001 May 29, 2025
bee83a8
Rm unused/fix comment
bmehta001 May 29, 2025
9f81fe9
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 May 29, 2025
b9abd98
Address comments
bmehta001 May 29, 2025
10162db
Fix None error + rm type-ignore
bmehta001 May 30, 2025
d06f388
Add more tests
bmehta001 May 30, 2025
a065541
Handle defaults and add docstring
bmehta001 May 30, 2025
1a90317
Add clarifying comment
bmehta001 May 30, 2025
d417c02
Fix correct default behavior for transpose
bmehta001 Jun 3, 2025
a23ee07
Formally drop python 3.8 support (#2354)
justinchuby May 29, 2025
d1eb856
Implement `__repr__` for MatchResult (#2353)
justinchuby May 30, 2025
19b7f6a
Use onnx_ir as a dependency (#2324)
justinchuby May 30, 2025
3fd79be
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms May 30, 2025
11075ee
Fix pytest for TestCosSinCacheTransform (#2358)
justinchuby Jun 2, 2025
9b81926
SDPA fusion cleanup (#2352)
gramalingam Jun 3, 2025
7553ce1
Require onnx-ir 0.1.1 (#2360)
justinchuby Jun 3, 2025
73432e5
Enable CSE in optimizer (#2361)
titaiwangms Jun 4, 2025
ccce52e
Rewrite tests and address comments
bmehta001 Jun 5, 2025
3654fa8
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 Jun 5, 2025
12b4cc0
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms May 30, 2025
2276a16
Enable CSE in optimizer (#2361)
titaiwangms Jun 4, 2025
2a0a798
Revert changes
bmehta001 Jun 5, 2025
5bcf2b1
Fix errors/simplify
bmehta001 Jun 5, 2025
a94c295
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 Jun 5, 2025
e976fb1
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 Jun 5, 2025
2adf8ea
Iterate through IR Model instead of ModelProto
bmehta001 Jun 5, 2025
73889d3
Addressed comments
bmehta001 Jun 5, 2025
841c49b
Simplify
bmehta001 Jun 5, 2025
04e6955
Simplify use of constants
bmehta001 Jun 5, 2025
514649f
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 Jun 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 253 additions & 55 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,49 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import ClassVar
from typing import ClassVar, Optional, Sequence

import onnxscript.rewriter.pattern as orp
from onnxscript import ir


def _get_node(value: ir.Value, name: str) -> ir.Node:
    """Return the node that produces ``value``.

    Args:
        value: The value whose producing node is wanted.
        name: Label used only in the assertion message when there is no producer.

    Returns:
        The node returned by ``value.producer()``.
    """
    producer = value.producer()
    assert producer is not None, f"{name} node should not be None"
    return producer


def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
    """Collect the node's attributes as a ``{name: attribute.value}`` mapping.

    The result is a fresh dict, so callers may modify it without touching
    ``node.attributes``.
    """
    return {attr_name: attr.value for attr_name, attr in node.attributes.items()}


def _get_int_or_default(node: ir.Node, name: str, default: int = 0) -> int:
    """Read attribute ``name`` from ``node`` as an int.

    Args:
        node: Node whose attribute mapping is consulted.
        name: Attribute name to look up.
        default: Value returned when the node has no such attribute.

    Returns:
        ``node.attributes[name].as_int()`` when the attribute exists, otherwise ``default``.
    """
    if name not in node.attributes:
        return default
    return node.attributes[name].as_int()


def _get_ints_or_default(
    node: ir.Node, name: str, default: Optional[Sequence[int]] = None
) -> Sequence[int]:
    """Read attribute ``name`` from ``node`` as a sequence of ints.

    Args:
        node: Node whose attribute mapping is consulted.
        name: Attribute name to look up.
        default: Value returned when the attribute is absent; ``None`` means
            "fall back to an empty list".

    Returns:
        ``node.attributes[name].as_ints()`` when the attribute exists, otherwise
        ``default`` if one was supplied, otherwise ``[]``.
    """
    if name in node.attributes:
        return node.attributes[name].as_ints()
    if default is None:
        return []
    return default


class FusedMatMulDiv1(orp.RewriteRuleClassBase):
"""Replaces ``MatMul + Div`` by FusedMatMul."""
"""Replaces ``MatMul + Div`` with MatMul."""

def pattern(self, op, x, y, cst):
return op.Div(op.MatMul(x, y), cst)
Expand All @@ -29,122 +65,278 @@


class FusedMatMulDiv2(orp.RewriteRuleClassBase):
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""
"""Replaces ``FusedMatMul + Div`` with FusedMatMul."""

def pattern(self, op, x, y, cst):
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst)

def check(self, context, x, y, cst) -> orp.MatchResult:
def check(self, context, x, y, cst, **_) -> orp.MatchResult:
check_result = orp.MatchResult()
if cst.const_value is None:
return check_result.fail("Divisor is not a constant value.")
if cst.const_value.numpy().size > 1:
return check_result.fail("Divisor is not a scalar value.")
return check_result

def rewrite(self, op, x, y, cst):
def rewrite(self, op, x, y, cst, fused: ir.Value):
value = cst.const_value.numpy()
c = float(value[0] if value.shape == (1,) else value)
node = list(x.uses())[0][0] # noqa: RUF015

kwargs = {}
alpha = node.attributes.get("alpha", None)
kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
for name in ["transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
fused_node = _get_node(fused, "FusedMatMul")
kwargs = _get_kwargs(fused_node)
kwargs["alpha"] = kwargs.get("alpha", 1.0) / c
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class _TransposeMatMulBase(orp.RewriteRuleClassBase):
_pos: ClassVar = 1

def check(self, context, x, y) -> orp.MatchResult:
def check(
self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_
) -> orp.MatchResult:
check_result = orp.MatchResult()
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
transposed_node = _get_node(transposed, "Transpose")
perm = _get_ints_or_default(transposed_node, "perm")
# If perm is not defined, Transpose reverses all dimensions, which equals swapping
# the last two dimensions only for rank-2 inputs (NOTE: the rank is not checked here).
if perm:
# Check that last two dimensions are swapped
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
if fused:
fused_node = _get_node(fused, "FusedMatMul")
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
if _get_int_or_default(fused_node, trans_batch_property):
return check_result.fail(

Check warning on line 111 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L111

Added line #L111 was not covered by tests
"FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
)
return check_result

def rewrite(self, op, x, y):
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
name = "transA" if self._pos == 1 else "transB"
kwargs[name] = 1 - kwargs.get(name, 0)
if fused:
fused_node = _get_node(fused, "FusedMatMul")
kwargs = _get_kwargs(fused_node)
trans_name = "transA" if self._pos == 1 else "transB"
kwargs[trans_name] = 1 - kwargs.get(trans_name, 0)
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")


class TransposeMatMul1(_TransposeMatMulBase):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.MatMul(op.Transpose(x), y)
return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y)


class TransposeFusedMatMul1(TransposeMatMul1):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")
return op.FusedMatMul(
op.Transpose(x, _outputs=["transposed"]),
y,
_domain="com.microsoft",
_outputs=["fused"],
)


class TransposeMatMul2(_TransposeMatMulBase):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""

_pos: ClassVar = 2

def pattern(self, op, x, y):
return op.MatMul(x, op.Transpose(y))
return op.MatMul(x, op.Transpose(y, _outputs=["transposed"]))


class TransposeFusedMatMul2(TransposeMatMul2):
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")
return op.FusedMatMul(
x,
op.Transpose(y, _outputs=["transposed"]),
_domain="com.microsoft",
_outputs=["fused"],
)


class _TransposeFusedMatMulBaseWithBatch(orp.RewriteRuleClassBase):
    """Replaces ``Transpose + FusedMatMul`` with FusedMatMul when batch transposes are involved.

    Applies either when transBatchA or transBatchB in FusedMatMul is 1, or when
    it can be inverted based on the permutation dims of the Transpose. This is in
    contrast to the original FusedMatMul rule, which assumes that transBatchA and
    transBatchB are always 0 before and after rewriting.

    transBatchA = 1, transA = 0 applies a batch transpose by moving the first
        dimension to the second-to-last position,
        i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1].
    transBatchA = 0, transA = 1 flips the last two dimensions,
        i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2].
    transBatchA = 1, transA = 1 applies a batch transpose, then flips the last
        two dimensions,
        i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0].

    The flipping logic is based on the following cases:
        Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
            or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
            - Then transBatchA and transA can be flipped in FusedMatMul when rewriting.
        Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
            or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
            - Then transBatchA can be flipped in FusedMatMul when rewriting.
        Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
            - Then transA can be flipped in FusedMatMul when rewriting.
    The same logic applies for transBatchB and transB, when _pos is set to 2.
    The _flip_transpose_batch and _flip_transpose flags are used to control
    which case is applied by the rules of inheriting classes that change these class vars.
    """

    # 1: the Transpose feeds the first FusedMatMul input (A); 2: the second input (B).
    _pos: ClassVar = 1
    # Whether rewrite() inverts transBatchA/B.
    _flip_transpose_batch: ClassVar = False
    # Whether rewrite() inverts transA/B.
    _flip_transpose: ClassVar = False

    def check(
        self, context, x, y, transposed: ir.Value, fused: ir.Value, **_
    ) -> orp.MatchResult:
        """Accept the match only if the Transpose "perm" fits the case selected by the flags.

        Args:
            context: Match context (unused).
            x: First input of the matched pattern (unused).
            y: Second input of the matched pattern (unused).
            transposed: Output value of the matched Transpose node.
            fused: Output value of the matched FusedMatMul node.

        Returns:
            A successful MatchResult when "perm" equals the permutation expected
            for the current flags and transBatch value; a failed one otherwise,
            including when the Transpose has no "perm" attribute.
        """
        check_result = orp.MatchResult()
        fused_node = _get_node(fused, "FusedMatMul")
        trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
        trans_batch = _get_int_or_default(fused_node, trans_batch_property)
        transposed_node = _get_node(transposed, "Transpose")
        # Use the shared helper so a Transpose without "perm" fails the match
        # instead of raising KeyError; list() keeps the comparisons below
        # independent of the sequence type the attribute returns.
        perm = list(_get_ints_or_default(transposed_node, "perm"))
        if not perm:
            return check_result.fail("Permutation values for Transpose are not correct.")

        list_perm = list(range(len(perm)))
        expected_perm = None
        if self._flip_transpose_batch and self._flip_transpose:
            # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
            #   or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
            #   - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting.
            if trans_batch == 0:
                expected_perm = [*list_perm[1:], list_perm[0]]
            else:
                expected_perm = [list_perm[-1], *list_perm[0:-1]]
        elif self._flip_transpose_batch:
            # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
            #   or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
            #   - Then transBatchA/B can be flipped in FusedMatMul when rewriting.
            if trans_batch == 0:
                expected_perm = [*list_perm[1:-1], list_perm[0], list_perm[-1]]
            else:
                expected_perm = [list_perm[-2], *list_perm[0:-2], list_perm[-1]]
        elif self._flip_transpose and trans_batch == 1:
            # Case 3: transBatchA/B is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
            #   - Then transA/B can be flipped in FusedMatMul when rewriting.
            expected_perm = [list_perm[-1], *list_perm[1:-1], list_perm[0]]

        if expected_perm is not None and expected_perm == perm:
            return check_result
        return check_result.fail("Permutation values for Transpose are not correct.")

    def rewrite(self, op, x, y, fused: ir.Value, **_):
        """Emit a FusedMatMul on the un-transposed inputs with the selected flags inverted.

        All attributes of the matched FusedMatMul are copied; transBatchA/B
        and/or transA/B (treated as 0 when absent) are then flipped according to
        ``_flip_transpose_batch`` and ``_flip_transpose``.
        """
        fused_node = _get_node(fused, "FusedMatMul")
        kwargs = _get_kwargs(fused_node)
        name = "A" if self._pos == 1 else "B"
        if self._flip_transpose_batch:
            trans_batch_property = f"transBatch{name}"
            kwargs[trans_batch_property] = 1 - kwargs.get(trans_batch_property, 0)
        if self._flip_transpose:
            trans_property = f"trans{name}"
            kwargs[trans_property] = 1 - kwargs.get(trans_property, 0)
        return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")

    def pattern(self, op, x, y):
        """Match a FusedMatMul whose input selected by ``_pos`` comes from a Transpose."""
        if self._pos == 1:
            return op.FusedMatMul(
                op.Transpose(x, _outputs=["transposed"]),
                y,
                _domain="com.microsoft",
                _outputs=["fused"],
            )
        else:
            return op.FusedMatMul(
                x,
                op.Transpose(y, _outputs=["transposed"]),
                _domain="com.microsoft",
                _outputs=["fused"],
            )


# Concrete rule classes derived from _TransposeFusedMatMulBaseWithBatch. Each is
# built with type() and overrides only the class-level flags it needs. Names
# ending in "1" keep the inherited _pos of 1 (Transpose on the first input);
# names ending in "2" set _pos to 2 (Transpose on the second input).

# Both flags set: rewrite flips transBatchA and transA.
TransposeFusedMatMulWithFlippedBatchAndTranspose1 = type(
"TransposeFusedMatMulWithFlippedBatchAndTranspose1",
(_TransposeFusedMatMulBaseWithBatch,),
{"_flip_transpose": True, "_flip_transpose_batch": True},
)
# Both flags set, second input: rewrite flips transBatchB and transB.
TransposeFusedMatMulWithFlippedBatchAndTranspose2 = type(
"TransposeFusedMatMulWithFlippedBatchAndTranspose2",
(_TransposeFusedMatMulBaseWithBatch,),
{"_pos": 2, "_flip_transpose": True, "_flip_transpose_batch": True},
)
# Only _flip_transpose_batch set: rewrite flips transBatchA.
TransposeFusedMatMulWithFlippedBatch1 = type(
"TransposeFusedMatMulWithFlippedBatch1",
(_TransposeFusedMatMulBaseWithBatch,),
{"_flip_transpose_batch": True},
)
# Only _flip_transpose_batch set, second input: rewrite flips transBatchB.
TransposeFusedMatMulWithFlippedBatch2 = type(
"TransposeFusedMatMulWithFlippedBatch2",
(_TransposeFusedMatMulBaseWithBatch,),
{"_pos": 2, "_flip_transpose_batch": True},
)
# Only _flip_transpose set: rewrite flips transA (check requires transBatchA == 1).
TransposeFusedMatMulWithBatchAndTranspose1 = type(
"TransposeFusedMatMulWithBatchAndTranspose1",
(_TransposeFusedMatMulBaseWithBatch,),
{"_flip_transpose": True},
)
# Only _flip_transpose set, second input: rewrite flips transB (check requires transBatchB == 1).
TransposeFusedMatMulWithBatchAndTranspose2 = type(
"TransposeFusedMatMulWithBatchAndTranspose2",
(_TransposeFusedMatMulBaseWithBatch,),
{"_pos": 2, "_flip_transpose": True},
)


class MatMulTranspose(orp.RewriteRuleClassBase):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
"""Replaces ``MatMul + Transpose`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.Transpose(op.MatMul(x, y))
return op.Transpose(op.MatMul(x, y), _outputs=["transposed"])

def check(self, context, x, y) -> orp.MatchResult:
def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
check_result = orp.MatchResult()
matmul = list(x.uses())[0][0] # noqa: RUF015
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
perm = transpose.attributes["perm"].value
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")
transpose_node = _get_node(transposed, "Transpose")
perm = _get_ints_or_default(transpose_node, "perm")
# If perm is not defined, Transpose reverses all dimensions, which equals swapping
# the last two dimensions only for rank-2 inputs (NOTE: the rank is not checked here).
if perm:
# Check that last two dimensions are swapped
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")

Check warning on line 319 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L319

Added line #L319 was not covered by tests
return check_result

def rewrite(self, op, x, y):
node = list(x.uses())[0][0] # noqa: RUF015
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
kwargs = {}
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
att = node.attributes.get(name)
if att:
kwargs[name] = att.value
if fused:
fused_node = _get_node(fused, "FusedMatMul")
kwargs = _get_kwargs(fused_node)
for name in ["transA", "transB"]:
kwargs[name] = 1 - kwargs.get(name, 0)
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")


class FusedMatMulTranspose(MatMulTranspose):
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
"""Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""

def pattern(self, op, x, y):
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))
return op.Transpose(
op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]),
_outputs=["transposed"],
)


def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
Expand All @@ -165,5 +357,11 @@
TransposeFusedMatMul1.rule(),
TransposeMatMul2.rule(),
TransposeFusedMatMul2.rule(),
TransposeFusedMatMulWithFlippedBatch1.rule(), # type: ignore[attr-defined]
TransposeFusedMatMulWithFlippedBatch2.rule(), # type: ignore[attr-defined]
TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(), # type: ignore[attr-defined]
TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(), # type: ignore[attr-defined]
TransposeFusedMatMulWithBatchAndTranspose1.rule(), # type: ignore[attr-defined]
TransposeFusedMatMulWithBatchAndTranspose2.rule(), # type: ignore[attr-defined]
]
)
Loading
Loading