Skip to content

Commit 5293005

Browse files
bmehta001justinchubytitaiwangmsCopilotgramalingam
authored
Fix fused matmul check/rewrite functions (#2331)
- Patterns now declare _outputs filters to bind intermediate values - Rewrites use fused.producer() or transposed.producer() instead of scanning .uses() which may pick up other nodes that use x or y - For ir.Value parameters, use a default of None in case the parameter does not exist - Attribute extraction updated to use as_float() / as_ints() for type safety - Since rewrite/check functions will have all ir.Value variables passed in, but they may not use all variables, use **_ to read in unused variables - Updated docstrings from "by" to "with" for clarity and changed fusedmatmul to matmul where appropriate - Add more patterns: 1. If Transpose.perm indices are [1:-1, 0, -1] and transBatchA is 0, we can change transBatchA to 1 2. If Transpose.perm indices are [-2, 0:-2, -1] and transBatchA is 1, we can change transBatchA to 0. 3. If Transpose.perm indices are [1:, 0] and transBatchA is 0, we can change transBatchA to 1 and transA to 1- transA 4. If Transpose.perm indices are [-1, 0:-1] and transBatchA is 1, we can change transBatchA to 0 and transA to 1- transA 5. If Transpose.perm indices are [-1, 1:-1, 0] and transBatchA is 1, we can change transA to 1- transA 6. And also do all of 1-5 for transBatchB - Add tests to make sure above changes work for `.producer()` and the added conditions related to `transBatch` All tests pass --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Justin Chu <[email protected]> Co-authored-by: Ti-Tai Wang <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: G. Ramalingam <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent af452c7 commit 5293005

File tree

2 files changed

+594
-324
lines changed

2 files changed

+594
-324
lines changed

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

Lines changed: 239 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,25 @@
55
from typing import ClassVar
66

77
import onnxscript.rewriter.pattern as orp
8+
from onnxscript import ir
9+
from onnxscript.rewriter import _ir_utils
10+
11+
12+
def _get_node(value: ir.Value, name: str) -> ir.Node:
13+
"""Get the node from the output value."""
14+
node = value.producer()
15+
assert node is not None, f"{name} node should not be None"
16+
return node
17+
18+
19+
def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
20+
"""Get the kwargs from the node."""
21+
kwargs = {key: val.value for key, val in node.attributes.items()}
22+
return kwargs
823

924

1025
class FusedMatMulDiv1(orp.RewriteRuleClassBase):
11-
"""Replaces ``MatMul + Div`` by FusedMatMul."""
26+
"""Replaces ``MatMul + Div`` with MatMul."""
1227

1328
def pattern(self, op, x, y, cst):
1429
return op.Div(op.MatMul(x, y), cst)
@@ -29,122 +44,286 @@ def rewrite(self, op, x, y, cst):
2944

3045

3146
class FusedMatMulDiv2(orp.RewriteRuleClassBase):
32-
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""
47+
"""Replaces ``FusedMatMul + Div`` with FusedMatMul."""
3348

3449
def pattern(self, op, x, y, cst):
35-
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
50+
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst)
3651

37-
def check(self, context, x, y, cst) -> orp.MatchResult:
52+
def check(self, context, x, y, cst, **_) -> orp.MatchResult:
3853
check_result = orp.MatchResult()
3954
if cst.const_value is None:
4055
return check_result.fail("Divisor is not a constant value.")
4156
if cst.const_value.numpy().size > 1:
4257
return check_result.fail("Divisor is not a scalar value.")
4358
return check_result
4459

45-
def rewrite(self, op, x, y, cst):
60+
def rewrite(self, op, x, y, cst, fused: ir.Value):
4661
value = cst.const_value.numpy()
4762
c = float(value[0] if value.shape == (1,) else value)
48-
node = list(x.uses())[0][0] # noqa: RUF015
49-
50-
kwargs = {}
51-
alpha = node.attributes.get("alpha", None)
52-
kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
53-
for name in ["transA", "transB", "transBatchA", "transBatchB"]:
54-
att = node.attributes.get(name)
55-
if att:
56-
kwargs[name] = att.value
63+
fused_node = _get_node(fused, "FusedMatMul")
64+
kwargs = _get_kwargs(fused_node)
65+
kwargs["alpha"] = kwargs.get("alpha", 1.0) / c
5766
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
5867

5968

6069
class _TransposeMatMulBase(orp.RewriteRuleClassBase):
6170
_pos: ClassVar = 1
6271

63-
def check(self, context, x, y) -> orp.MatchResult:
72+
def check(
73+
self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_
74+
) -> orp.MatchResult:
6475
check_result = orp.MatchResult()
65-
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
66-
expected_perm = list(range(len(perm)))
67-
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
68-
if perm != expected_perm:
69-
return check_result.fail("Permutation values for Transpose are not correct.")
76+
transposed_node = _get_node(transposed, "Transpose")
77+
perm = transposed_node.attributes.get_ints("perm")
78+
if perm:
79+
# Check that last two dimensions are swapped
80+
expected_perm = list(range(len(perm)))
81+
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
82+
if perm != expected_perm:
83+
return check_result.fail("Permutation values for Transpose are not correct.")
84+
elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
85+
self._pos == 2 and not _ir_utils.has_rank(y, 2)
86+
):
87+
# If perm is not defined, the default transpose behavior is to swap
88+
# all dimensions, which is correct for MatMul with rank = 2.
89+
return check_result.fail(
90+
"If perm is not defined, rank must be 2 for TransposeMatMul rule."
91+
)
92+
if fused:
93+
fused_node = _get_node(fused, "FusedMatMul")
94+
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
95+
if fused_node.attributes.get_int(trans_batch_property, 0):
96+
return check_result.fail(
97+
"FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
98+
)
7099
return check_result
71100

72-
def rewrite(self, op, x, y):
73-
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015
101+
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
74102
kwargs = {}
75-
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
76-
att = node.attributes.get(name)
77-
if att:
78-
kwargs[name] = att.value
79-
name = "transA" if self._pos == 1 else "transB"
80-
kwargs[name] = 1 - kwargs.get(name, 0)
103+
if fused:
104+
fused_node = _get_node(fused, "FusedMatMul")
105+
kwargs = _get_kwargs(fused_node)
106+
trans_name = "transA" if self._pos == 1 else "transB"
107+
kwargs[trans_name] = 1 - kwargs.get(trans_name, 0)
81108
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
82109

83110

84111
class TransposeMatMul1(_TransposeMatMulBase):
85-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
112+
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""
86113

87114
def pattern(self, op, x, y):
88-
return op.MatMul(op.Transpose(x), y)
115+
return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y)
89116

90117

91118
class TransposeFusedMatMul1(TransposeMatMul1):
92-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
119+
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul."""
93120

94121
def pattern(self, op, x, y):
95-
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")
122+
return op.FusedMatMul(
123+
op.Transpose(x, _outputs=["transposed"]),
124+
y,
125+
_domain="com.microsoft",
126+
_outputs=["fused"],
127+
)
96128

97129

98130
class TransposeMatMul2(_TransposeMatMulBase):
99-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
131+
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""
100132

101133
_pos: ClassVar = 2
102134

103135
def pattern(self, op, x, y):
104-
return op.MatMul(x, op.Transpose(y))
136+
return op.MatMul(x, op.Transpose(y, _outputs=["transposed"]))
105137

106138

107139
class TransposeFusedMatMul2(TransposeMatMul2):
108-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
140+
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul."""
109141

110142
def pattern(self, op, x, y):
111-
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")
143+
return op.FusedMatMul(
144+
x,
145+
op.Transpose(y, _outputs=["transposed"]),
146+
_domain="com.microsoft",
147+
_outputs=["fused"],
148+
)
149+
150+
151+
class _TransposeFusedMatMulBaseWithBatch(orp.RewriteRuleClassBase):
152+
"""Replaces ``Transpose + FusedMatMul`` with FusedMatMul, either
153+
when transBatchA or transBatchB in FusedMatMul is 1, or
154+
can be inverted based on the permutation dims of the Transpose, in
155+
contrast to the original FusedMatMul rule which assumes that
156+
transBatchA and transBatchB are always 0 before and after rewriting.
157+
158+
transBatchA = 1, transA = 0 applies a batch transpose by moving the first dimension to the second-to-last position
159+
i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1].
160+
transBatchA = 0, transA = 1 flips the last two dimensions
161+
i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2].
162+
transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions
163+
i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0].
164+
165+
The flipping logic is based on the following cases:
166+
Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
167+
or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
168+
- Then transBatchA and transA can be flipped in FusedMatMul when rewriting.
169+
Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
170+
or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
171+
- Then transBatchA can be flipped in FusedMatMul when rewriting.
172+
Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
173+
- Then transA can be flipped in FusedMatMul when rewriting.
174+
The same logic applies for transBatchB and transB, when _pos is set to 2.
175+
The _flip_transpose_batch and _flip_transpose flags are used to control
176+
which case is applied by the rules of inheriting classes that change these class vars.
177+
"""
178+
179+
_pos: ClassVar = 1
180+
_flip_transpose_batch: ClassVar = False
181+
_flip_transpose: ClassVar = False
182+
183+
def check(
184+
self, context, x, y, transposed: ir.Value, fused: ir.Value, **_
185+
) -> orp.MatchResult:
186+
check_result = orp.MatchResult()
187+
fused_node = _get_node(fused, "FusedMatMul")
188+
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB"
189+
trans_batch = fused_node.attributes.get_int(trans_batch_property, 0)
190+
transposed_node = _get_node(transposed, "Transpose")
191+
perm = transposed_node.attributes["perm"].as_ints()
192+
if not perm:
193+
return check_result.fail("Permutation values for Transpose are not correct.")
194+
195+
list_perm = list(range(len(perm)))
196+
if self._flip_transpose_batch and self._flip_transpose:
197+
# Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
198+
# or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
199+
# - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting.
200+
if trans_batch == 0:
201+
expected_perm = [*list_perm[1:], list_perm[0]]
202+
else:
203+
expected_perm = [list_perm[-1], *list_perm[0:-1]]
204+
if expected_perm == perm:
205+
return check_result
206+
elif self._flip_transpose_batch:
207+
# Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
208+
# or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
209+
# - Then transBatchA/B can be flipped in FusedMatMul when rewriting.
210+
if trans_batch == 0:
211+
expected_perm = [*list_perm[1:-1], list_perm[0], list_perm[-1]]
212+
else:
213+
expected_perm = [list_perm[-2], *list_perm[0:-2], list_perm[-1]]
214+
if expected_perm == perm:
215+
return check_result
216+
elif self._flip_transpose:
217+
# Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
218+
# - Then transA can be flipped in FusedMatMul when rewriting.
219+
expected_perm = [list_perm[-1], *list_perm[1:-1], list_perm[0]]
220+
if expected_perm == perm and trans_batch == 1:
221+
return check_result
222+
223+
return check_result.fail("Permutation values for Transpose are not correct.")
224+
225+
def rewrite(self, op, x, y, fused: ir.Value, **_):
226+
kwargs = {}
227+
fused_node = _get_node(fused, "FusedMatMul")
228+
kwargs = _get_kwargs(fused_node)
229+
name = "A" if self._pos == 1 else "B"
230+
if self._flip_transpose_batch:
231+
trans_batch_property = f"transBatch{name}"
232+
kwargs[trans_batch_property] = 1 - kwargs.get(trans_batch_property, 0)
233+
if self._flip_transpose:
234+
trans_property = f"trans{name}"
235+
kwargs[trans_property] = 1 - kwargs.get(trans_property, 0)
236+
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
237+
238+
def pattern(self, op, x, y):
239+
if self._pos == 1:
240+
return op.FusedMatMul(
241+
op.Transpose(x, _outputs=["transposed"]),
242+
y,
243+
_domain="com.microsoft",
244+
_outputs=["fused"],
245+
)
246+
else:
247+
return op.FusedMatMul(
248+
x,
249+
op.Transpose(y, _outputs=["transposed"]),
250+
_domain="com.microsoft",
251+
_outputs=["fused"],
252+
)
253+
254+
255+
class TransposeFusedMatMulWithFlippedBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch):
256+
_flip_transpose = True
257+
_flip_transpose_batch = True
258+
259+
260+
class TransposeFusedMatMulWithFlippedBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch):
261+
_pos = 2
262+
_flip_transpose = True
263+
_flip_transpose_batch = True
264+
265+
266+
class TransposeFusedMatMulWithFlippedBatch1(_TransposeFusedMatMulBaseWithBatch):
267+
_flip_transpose_batch = True
268+
269+
270+
class TransposeFusedMatMulWithFlippedBatch2(_TransposeFusedMatMulBaseWithBatch):
271+
_pos = 2
272+
_flip_transpose_batch = True
273+
274+
275+
class TransposeFusedMatMulWithBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch):
276+
_flip_transpose = True
277+
278+
279+
class TransposeFusedMatMulWithBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch):
280+
_pos = 2
281+
_flip_transpose = True
112282

113283

114284
class MatMulTranspose(orp.RewriteRuleClassBase):
115-
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
285+
"""Replaces ``MatMul + Transpose`` with FusedMatMul."""
116286

117287
def pattern(self, op, x, y):
118-
return op.Transpose(op.MatMul(x, y))
288+
return op.Transpose(op.MatMul(x, y), _outputs=["transposed"])
119289

120-
def check(self, context, x, y) -> orp.MatchResult:
290+
def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
121291
check_result = orp.MatchResult()
122-
matmul = list(x.uses())[0][0] # noqa: RUF015
123-
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
124-
perm = transpose.attributes["perm"].value
125-
expected_perm = list(range(len(perm)))
126-
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
127-
if perm != expected_perm:
128-
return check_result.fail("Permutation values for Transpose are not correct.")
292+
transpose_node = _get_node(transposed, "Transpose")
293+
perm = transpose_node.attributes.get_ints("perm")
294+
# transA/transB only work on the last two dimensions of the input,
295+
# so we can only apply this rule if the inputs are rank 2.
296+
if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2):
297+
if perm:
298+
# Check that the two dimensions are swapped
299+
if perm != [1, 0]:
300+
return check_result.fail(
301+
"Permutation values for Transpose are not correct."
302+
)
303+
# If perm is not defined, the default transpose behavior is to swap
304+
# all dimensions, which is correct for MatMul with rank = 2.
305+
else:
306+
return check_result.fail("Rank must be 2 for MatMulTranspose rule.")
129307
return check_result
130308

131-
def rewrite(self, op, x, y):
132-
node = list(x.uses())[0][0] # noqa: RUF015
309+
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
133310
kwargs = {}
134-
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
135-
att = node.attributes.get(name)
136-
if att:
137-
kwargs[name] = att.value
311+
if fused:
312+
fused_node = _get_node(fused, "FusedMatMul")
313+
kwargs = _get_kwargs(fused_node)
138314
for name in ["transA", "transB"]:
139315
kwargs[name] = 1 - kwargs.get(name, 0)
140316
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")
141317

142318

143319
class FusedMatMulTranspose(MatMulTranspose):
144-
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
320+
"""Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""
145321

146322
def pattern(self, op, x, y):
147-
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))
323+
return op.Transpose(
324+
op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]),
325+
_outputs=["transposed"],
326+
)
148327

149328

150329
def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
@@ -165,5 +344,11 @@ def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
165344
TransposeFusedMatMul1.rule(),
166345
TransposeMatMul2.rule(),
167346
TransposeFusedMatMul2.rule(),
347+
TransposeFusedMatMulWithFlippedBatch1.rule(),
348+
TransposeFusedMatMulWithFlippedBatch2.rule(),
349+
TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(),
350+
TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(),
351+
TransposeFusedMatMulWithBatchAndTranspose1.rule(),
352+
TransposeFusedMatMulWithBatchAndTranspose2.rule(),
168353
]
169354
)

0 commit comments

Comments
 (0)