5
5
from typing import ClassVar
6
6
7
7
import onnxscript .rewriter .pattern as orp
8
+ from onnxscript import ir
9
+ from onnxscript .rewriter import _ir_utils
10
+
11
+
12
+ def _get_node (value : ir .Value , name : str ) -> ir .Node :
13
+ """Get the node from the output value."""
14
+ node = value .producer ()
15
+ assert node is not None , f"{ name } node should not be None"
16
+ return node
17
+
18
+
19
+ def _get_kwargs (node : ir .Node ) -> dict [str , float | int ]:
20
+ """Get the kwargs from the node."""
21
+ kwargs = {key : val .value for key , val in node .attributes .items ()}
22
+ return kwargs
8
23
9
24
10
25
class FusedMatMulDiv1 (orp .RewriteRuleClassBase ):
11
- """Replaces ``MatMul + Div`` by FusedMatMul ."""
26
+ """Replaces ``MatMul + Div`` with MatMul ."""
12
27
13
28
def pattern (self , op , x , y , cst ):
14
29
return op .Div (op .MatMul (x , y ), cst )
@@ -29,122 +44,286 @@ def rewrite(self, op, x, y, cst):
29
44
30
45
31
46
class FusedMatMulDiv2 (orp .RewriteRuleClassBase ):
32
- """Replaces ``FusedMatMul + Div`` by FusedMatMul."""
47
+ """Replaces ``FusedMatMul + Div`` with FusedMatMul."""
33
48
34
49
def pattern (self , op , x , y , cst ):
35
- return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" ), cst )
50
+ return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = [ "fused" ] ), cst )
36
51
37
- def check (self , context , x , y , cst ) -> orp .MatchResult :
52
+ def check (self , context , x , y , cst , ** _ ) -> orp .MatchResult :
38
53
check_result = orp .MatchResult ()
39
54
if cst .const_value is None :
40
55
return check_result .fail ("Divisor is not a constant value." )
41
56
if cst .const_value .numpy ().size > 1 :
42
57
return check_result .fail ("Divisor is not a scalar value." )
43
58
return check_result
44
59
45
- def rewrite (self , op , x , y , cst ):
60
+ def rewrite (self , op , x , y , cst , fused : ir . Value ):
46
61
value = cst .const_value .numpy ()
47
62
c = float (value [0 ] if value .shape == (1 ,) else value )
48
- node = list (x .uses ())[0 ][0 ] # noqa: RUF015
49
-
50
- kwargs = {}
51
- alpha = node .attributes .get ("alpha" , None )
52
- kwargs ["alpha" ] = alpha .value / c if alpha else 1.0 / c
53
- for name in ["transA" , "transB" , "transBatchA" , "transBatchB" ]:
54
- att = node .attributes .get (name )
55
- if att :
56
- kwargs [name ] = att .value
63
+ fused_node = _get_node (fused , "FusedMatMul" )
64
+ kwargs = _get_kwargs (fused_node )
65
+ kwargs ["alpha" ] = kwargs .get ("alpha" , 1.0 ) / c
57
66
return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
58
67
59
68
60
69
class _TransposeMatMulBase (orp .RewriteRuleClassBase ):
61
70
_pos : ClassVar = 1
62
71
63
- def check (self , context , x , y ) -> orp .MatchResult :
72
+ def check (
73
+ self , context , x , y , transposed : ir .Value , fused : ir .Value | None = None , ** _
74
+ ) -> orp .MatchResult :
64
75
check_result = orp .MatchResult ()
65
- perm = list ((x if self ._pos == 1 else y ).uses ())[0 ][0 ].attributes ["perm" ].value # noqa: RUF015
66
- expected_perm = list (range (len (perm )))
67
- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
68
- if perm != expected_perm :
69
- return check_result .fail ("Permutation values for Transpose are not correct." )
76
+ transposed_node = _get_node (transposed , "Transpose" )
77
+ perm = transposed_node .attributes .get_ints ("perm" )
78
+ if perm :
79
+ # Check that last two dimensions are swapped
80
+ expected_perm = list (range (len (perm )))
81
+ expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
82
+ if perm != expected_perm :
83
+ return check_result .fail ("Permutation values for Transpose are not correct." )
84
+ elif (self ._pos == 1 and not _ir_utils .has_rank (x , 2 )) or (
85
+ self ._pos == 2 and not _ir_utils .has_rank (y , 2 )
86
+ ):
87
+ # If perm is not defined, the default transpose behavior is to swap
88
+ # all dimensions, which is correct for MatMul with rank = 2.
89
+ return check_result .fail (
90
+ "If perm is not defined, rank must be 2 for TransposeMatMul rule."
91
+ )
92
+ if fused :
93
+ fused_node = _get_node (fused , "FusedMatMul" )
94
+ trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
95
+ if fused_node .attributes .get_int (trans_batch_property , 0 ):
96
+ return check_result .fail (
97
+ "FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
98
+ )
70
99
return check_result
71
100
72
- def rewrite (self , op , x , y ):
73
- node = list ((x if self ._pos == 2 else y ).uses ())[0 ][0 ] # noqa: RUF015
101
+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
74
102
kwargs = {}
75
- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
76
- att = node .attributes .get (name )
77
- if att :
78
- kwargs [name ] = att .value
79
- name = "transA" if self ._pos == 1 else "transB"
80
- kwargs [name ] = 1 - kwargs .get (name , 0 )
103
+ if fused :
104
+ fused_node = _get_node (fused , "FusedMatMul" )
105
+ kwargs = _get_kwargs (fused_node )
106
+ trans_name = "transA" if self ._pos == 1 else "transB"
107
+ kwargs [trans_name ] = 1 - kwargs .get (trans_name , 0 )
81
108
return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
82
109
83
110
84
111
class TransposeMatMul1 (_TransposeMatMulBase ):
85
- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
112
+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
86
113
87
114
def pattern (self , op , x , y ):
88
- return op .MatMul (op .Transpose (x ), y )
115
+ return op .MatMul (op .Transpose (x , _outputs = [ "transposed" ] ), y )
89
116
90
117
91
118
class TransposeFusedMatMul1 (TransposeMatMul1 ):
92
- """Replaces ``Transpose + (Fused)MatMul `` by FusedMatMul."""
119
+ """Replaces ``Transpose + FusedMatMul `` with FusedMatMul."""
93
120
94
121
def pattern (self , op , x , y ):
95
- return op .FusedMatMul (op .Transpose (x ), y , _domain = "com.microsoft" )
122
+ return op .FusedMatMul (
123
+ op .Transpose (x , _outputs = ["transposed" ]),
124
+ y ,
125
+ _domain = "com.microsoft" ,
126
+ _outputs = ["fused" ],
127
+ )
96
128
97
129
98
130
class TransposeMatMul2 (_TransposeMatMulBase ):
99
- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
131
+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
100
132
101
133
_pos : ClassVar = 2
102
134
103
135
def pattern (self , op , x , y ):
104
- return op .MatMul (x , op .Transpose (y ))
136
+ return op .MatMul (x , op .Transpose (y , _outputs = [ "transposed" ] ))
105
137
106
138
107
139
class TransposeFusedMatMul2 (TransposeMatMul2 ):
108
- """Replaces ``Transpose + (Fused)MatMul `` by FusedMatMul."""
140
+ """Replaces ``Transpose + FusedMatMul `` with FusedMatMul."""
109
141
110
142
def pattern (self , op , x , y ):
111
- return op .FusedMatMul (x , op .Transpose (y ), _domain = "com.microsoft" )
143
+ return op .FusedMatMul (
144
+ x ,
145
+ op .Transpose (y , _outputs = ["transposed" ]),
146
+ _domain = "com.microsoft" ,
147
+ _outputs = ["fused" ],
148
+ )
149
+
150
+
151
+ class _TransposeFusedMatMulBaseWithBatch (orp .RewriteRuleClassBase ):
152
+ """Replaces ``Transpose + FusedMatMul`` with FusedMatMul, either
153
+ when transBatchA or transBatchB in FusedMatMul is 1, or
154
+ can be inverted based on the permutation dims of the Transpose, in
155
+ contrast to the original FusedMatMul rule which assumes that
156
+ transBatchA and transBatchB are always 0 before and after rewriting.
157
+
158
+ transBatchA = 1, transA = 0 applies a batch transpose by moving the first dimension to the second-to-last position
159
+ i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1].
160
+ transBatchA = 0, transA = 1 flips the last two dimensions
161
+ i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2].
162
+ transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions
163
+ i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0].
164
+
165
+ The flipping logic is based on the following cases:
166
+ Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
167
+ or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
168
+ - Then transBatchA and transA can be flipped in FusedMatMul when rewriting.
169
+ Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
170
+ or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
171
+ - Then transBatchA can be flipped in FusedMatMul when rewriting.
172
+ Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
173
+ - Then transA can be flipped in FusedMatMul when rewriting.
174
+ The same logic applies for transBatchB and transB, when _pos is set to 2.
175
+ The _flip_transpose_batch and _flip_transpose flags are used to control
176
+ which case is applied by the rules of inheriting classes that change these class vars.
177
+ """
178
+
179
+ _pos : ClassVar = 1
180
+ _flip_transpose_batch : ClassVar = False
181
+ _flip_transpose : ClassVar = False
182
+
183
+ def check (
184
+ self , context , x , y , transposed : ir .Value , fused : ir .Value , ** _
185
+ ) -> orp .MatchResult :
186
+ check_result = orp .MatchResult ()
187
+ fused_node = _get_node (fused , "FusedMatMul" )
188
+ trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
189
+ trans_batch = fused_node .attributes .get_int (trans_batch_property , 0 )
190
+ transposed_node = _get_node (transposed , "Transpose" )
191
+ perm = transposed_node .attributes ["perm" ].as_ints ()
192
+ if not perm :
193
+ return check_result .fail ("Permutation values for Transpose are not correct." )
194
+
195
+ list_perm = list (range (len (perm )))
196
+ if self ._flip_transpose_batch and self ._flip_transpose :
197
+ # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
198
+ # or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
199
+ # - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting.
200
+ if trans_batch == 0 :
201
+ expected_perm = [* list_perm [1 :], list_perm [0 ]]
202
+ else :
203
+ expected_perm = [list_perm [- 1 ], * list_perm [0 :- 1 ]]
204
+ if expected_perm == perm :
205
+ return check_result
206
+ elif self ._flip_transpose_batch :
207
+ # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
208
+ # or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
209
+ # - Then transBatchA/B can be flipped in FusedMatMul when rewriting.
210
+ if trans_batch == 0 :
211
+ expected_perm = [* list_perm [1 :- 1 ], list_perm [0 ], list_perm [- 1 ]]
212
+ else :
213
+ expected_perm = [list_perm [- 2 ], * list_perm [0 :- 2 ], list_perm [- 1 ]]
214
+ if expected_perm == perm :
215
+ return check_result
216
+ elif self ._flip_transpose :
217
+ # Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
218
+ # - Then transA can be flipped in FusedMatMul when rewriting.
219
+ expected_perm = [list_perm [- 1 ], * list_perm [1 :- 1 ], list_perm [0 ]]
220
+ if expected_perm == perm and trans_batch == 1 :
221
+ return check_result
222
+
223
+ return check_result .fail ("Permutation values for Transpose are not correct." )
224
+
225
+ def rewrite (self , op , x , y , fused : ir .Value , ** _ ):
226
+ kwargs = {}
227
+ fused_node = _get_node (fused , "FusedMatMul" )
228
+ kwargs = _get_kwargs (fused_node )
229
+ name = "A" if self ._pos == 1 else "B"
230
+ if self ._flip_transpose_batch :
231
+ trans_batch_property = f"transBatch{ name } "
232
+ kwargs [trans_batch_property ] = 1 - kwargs .get (trans_batch_property , 0 )
233
+ if self ._flip_transpose :
234
+ trans_property = f"trans{ name } "
235
+ kwargs [trans_property ] = 1 - kwargs .get (trans_property , 0 )
236
+ return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
237
+
238
+ def pattern (self , op , x , y ):
239
+ if self ._pos == 1 :
240
+ return op .FusedMatMul (
241
+ op .Transpose (x , _outputs = ["transposed" ]),
242
+ y ,
243
+ _domain = "com.microsoft" ,
244
+ _outputs = ["fused" ],
245
+ )
246
+ else :
247
+ return op .FusedMatMul (
248
+ x ,
249
+ op .Transpose (y , _outputs = ["transposed" ]),
250
+ _domain = "com.microsoft" ,
251
+ _outputs = ["fused" ],
252
+ )
253
+
254
+
255
+ class TransposeFusedMatMulWithFlippedBatchAndTranspose1 (_TransposeFusedMatMulBaseWithBatch ):
256
+ _flip_transpose = True
257
+ _flip_transpose_batch = True
258
+
259
+
260
+ class TransposeFusedMatMulWithFlippedBatchAndTranspose2 (_TransposeFusedMatMulBaseWithBatch ):
261
+ _pos = 2
262
+ _flip_transpose = True
263
+ _flip_transpose_batch = True
264
+
265
+
266
+ class TransposeFusedMatMulWithFlippedBatch1 (_TransposeFusedMatMulBaseWithBatch ):
267
+ _flip_transpose_batch = True
268
+
269
+
270
+ class TransposeFusedMatMulWithFlippedBatch2 (_TransposeFusedMatMulBaseWithBatch ):
271
+ _pos = 2
272
+ _flip_transpose_batch = True
273
+
274
+
275
+ class TransposeFusedMatMulWithBatchAndTranspose1 (_TransposeFusedMatMulBaseWithBatch ):
276
+ _flip_transpose = True
277
+
278
+
279
+ class TransposeFusedMatMulWithBatchAndTranspose2 (_TransposeFusedMatMulBaseWithBatch ):
280
+ _pos = 2
281
+ _flip_transpose = True
112
282
113
283
114
284
class MatMulTranspose (orp .RewriteRuleClassBase ):
115
- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
285
+ """Replaces ``MatMul + Transpose`` with FusedMatMul."""
116
286
117
287
def pattern (self , op , x , y ):
118
- return op .Transpose (op .MatMul (x , y ))
288
+ return op .Transpose (op .MatMul (x , y ), _outputs = [ "transposed" ] )
119
289
120
- def check (self , context , x , y ) -> orp .MatchResult :
290
+ def check (self , context , x , y , transposed : ir . Value , ** _ ) -> orp .MatchResult :
121
291
check_result = orp .MatchResult ()
122
- matmul = list (x .uses ())[0 ][0 ] # noqa: RUF015
123
- transpose = list (matmul .outputs [0 ].uses ())[0 ][0 ] # noqa: RUF015
124
- perm = transpose .attributes ["perm" ].value
125
- expected_perm = list (range (len (perm )))
126
- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
127
- if perm != expected_perm :
128
- return check_result .fail ("Permutation values for Transpose are not correct." )
292
+ transpose_node = _get_node (transposed , "Transpose" )
293
+ perm = transpose_node .attributes .get_ints ("perm" )
294
+ # transA/transB only work on the last two dimensions of the input,
295
+ # so we can only apply this rule if the inputs are rank 2.
296
+ if _ir_utils .has_rank (x , 2 ) and _ir_utils .has_rank (y , 2 ):
297
+ if perm :
298
+ # Check that the two dimensions are swapped
299
+ if perm != [1 , 0 ]:
300
+ return check_result .fail (
301
+ "Permutation values for Transpose are not correct."
302
+ )
303
+ # If perm is not defined, the default transpose behavior is to swap
304
+ # all dimensions, which is correct for MatMul with rank = 2.
305
+ else :
306
+ return check_result .fail ("Rank must be 2 for MatMulTranspose rule." )
129
307
return check_result
130
308
131
- def rewrite (self , op , x , y ):
132
- node = list (x .uses ())[0 ][0 ] # noqa: RUF015
309
+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
133
310
kwargs = {}
134
- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
135
- att = node .attributes .get (name )
136
- if att :
137
- kwargs [name ] = att .value
311
+ if fused :
312
+ fused_node = _get_node (fused , "FusedMatMul" )
313
+ kwargs = _get_kwargs (fused_node )
138
314
for name in ["transA" , "transB" ]:
139
315
kwargs [name ] = 1 - kwargs .get (name , 0 )
140
316
return op .FusedMatMul (y , x , ** kwargs , _domain = "com.microsoft" )
141
317
142
318
143
319
class FusedMatMulTranspose (MatMulTranspose ):
144
- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
320
+ """Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""
145
321
146
322
def pattern (self , op , x , y ):
147
- return op .Transpose (op .FusedMatMul (x , y , _domain = "com.microsoft" ))
323
+ return op .Transpose (
324
+ op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = ["fused" ]),
325
+ _outputs = ["transposed" ],
326
+ )
148
327
149
328
150
329
def fused_matmul_rule_sets () -> orp .RewriteRuleSet :
@@ -165,5 +344,11 @@ def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
165
344
TransposeFusedMatMul1 .rule (),
166
345
TransposeMatMul2 .rule (),
167
346
TransposeFusedMatMul2 .rule (),
347
+ TransposeFusedMatMulWithFlippedBatch1 .rule (),
348
+ TransposeFusedMatMulWithFlippedBatch2 .rule (),
349
+ TransposeFusedMatMulWithFlippedBatchAndTranspose1 .rule (),
350
+ TransposeFusedMatMulWithFlippedBatchAndTranspose2 .rule (),
351
+ TransposeFusedMatMulWithBatchAndTranspose1 .rule (),
352
+ TransposeFusedMatMulWithBatchAndTranspose2 .rule (),
168
353
]
169
354
)
0 commit comments