11
11
from onnxscript .rewriter import pattern as orp
12
12
13
13
14
+ class SqueezeReshape (orp .RewriteRuleClassBase ):
15
+ """Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
16
+
17
+ This pattern arises from the translation of pytorch symints.
18
+ """
19
+
20
+ def __init__ (self ):
21
+ super ().__init__ ("SqueezeReshape1d" , remove_nodes = False )
22
+
23
+ def pattern (self , op , x ):
24
+ return op .Reshape (op .Squeeze (x ), [- 1 ])
25
+
26
+ def rewrite (self , op , x : ir .Value ):
27
+ return op .Identity (x )
28
+
29
+ def check (self , context , x ) -> orp .MatchResult :
30
+ del context # Unused
31
+ check_result = orp .MatchResult ()
32
+ if not ir_utils .has_rank (x , 1 ):
33
+ return check_result .fail ("Input is not 1D" )
34
+ return check_result
35
+
36
+
14
37
class CastIdentity (orp .RewriteRuleAsClass ):
15
38
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
16
39
@@ -23,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr):
23
46
return op .Identity (x )
24
47
25
48
@classmethod
26
- def check (cls , context , x , to ) -> bool :
27
- return x .dtype == to .value
49
+ def check (cls , context , x , to ) -> orp .MatchResult :
50
+ check_result = orp .MatchResult ()
51
+ if x .dtype != to .value :
52
+ return check_result .fail ("Input and output types are not the same" )
53
+ return check_result
28
54
29
55
30
56
class CastCast (orp .RewriteRuleAsClass ):
@@ -42,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored):
42
68
return op .Cast (op .Cast (x , to = to_ignored ), to = to )
43
69
44
70
@classmethod
45
- def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> bool :
46
- return (
47
- to .value in cls ._allowed_tensor_types
48
- and to_ignored .value in cls ._allowed_tensor_types
49
- )
71
+ def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
72
+ check_result = orp .MatchResult ()
73
+ if to .value not in cls ._allowed_tensor_types :
74
+ return check_result .fail (f"Output type { to .value } is not allowed" )
75
+ if to_ignored .value not in cls ._allowed_tensor_types :
76
+ return check_result .fail (f"Ignored type { to_ignored .value } is not allowed" )
77
+ return check_result
50
78
51
79
@classmethod
52
80
def rewrite (cls , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
@@ -65,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value):
65
93
return op .Identity (x )
66
94
67
95
@classmethod
68
- def check (cls , context , x , shape ) -> bool :
96
+ def check (cls , context , x , shape ) -> orp .MatchResult :
97
+ check_result = orp .MatchResult ()
69
98
if shape .const_value is None :
70
99
# Shape is not a constant and cannot be guessed.
71
- return False
100
+ return check_result . fail ( "Shape is not a constant and cannot be guessed." )
72
101
if (x_shape := x .shape ) is None :
73
102
# We don't know the shape of the input
74
- return False
75
- return x_shape .dims == tuple (shape .const_value .numpy ().tolist ())
103
+ return check_result .fail ("Input shape is not known." )
104
+ if x_shape .dims != tuple (shape .const_value .numpy ().tolist ()):
105
+ return check_result .fail (
106
+ f"Input shape { x_shape .dims } does not match the shape { shape .const_value .numpy ().tolist ()} ."
107
+ )
108
+ return check_result
76
109
77
110
78
111
class ReshapeReshape (orp .RewriteRuleAsClass ):
@@ -90,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
90
123
return op .Reshape (x , shape )
91
124
92
125
@classmethod
93
- def check (cls , context , x , shape_ignored , shape ) -> bool :
94
- if shape_ignored .const_value is None or shape .const_value is None :
95
- return False
126
+ def check (cls , context , x , shape_ignored , shape ) -> orp .MatchResult :
127
+ check_result = orp .MatchResult ()
128
+ if shape_ignored .const_value is None :
129
+ return check_result .fail ("Shape ignored is not a constant." )
130
+ if shape .const_value is None :
131
+ return check_result .fail ("Shape is not a constant." )
96
132
if shape .const_value .numpy ().min () <= 0 :
97
- return False
98
- return True
133
+ return check_result . fail ( "Shape has non-positive values." )
134
+ return check_result
99
135
100
136
101
137
class SlicesSplit (orp .RewriteRuleAsClass ):
@@ -108,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
108
144
return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
109
145
110
146
@classmethod
111
- def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> bool :
147
+ def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
148
+ check_result = orp .MatchResult ()
112
149
if (
113
150
axes0 .const_value is None
114
151
or axes1 .const_value is None
115
152
or axes0 .const_value .numpy ().tolist () != axes1 .const_value .numpy ().tolist ()
116
153
):
117
- return False
154
+ return check_result . fail ( "Axes are not equal or not constant." )
118
155
axes = axes0 .const_value .numpy ().tolist ()
119
156
if len (axes ) != 1 :
120
- return False
157
+ return check_result . fail ( "Axes has more than one dimension." )
121
158
if x .shape :
122
159
rk = len (x .shape )
123
160
else :
124
161
rk = x .rank
125
162
if axes [0 ] != - 1 and axes [0 ] != rk - 1 :
126
- return False
163
+ return check_result . fail ( "Axes is not -1 or last dimension." )
127
164
if (
128
165
begin0 .const_value is None
129
166
or end0 .const_value is None
130
167
or begin1 .const_value is None
131
168
or end1 .const_value is None
132
169
):
133
- return False
170
+ return check_result . fail ( "Begin or end are not constant values." )
134
171
if begin0 .const_value .numpy ().tolist () != [0 ]:
135
- return False
172
+ return check_result . fail ( "First begin value is not 0." )
136
173
e0 , b1 , e1 = (
137
174
end0 .const_value .numpy ().tolist (),
138
175
begin1 .const_value .numpy ().tolist (),
139
176
end1 .const_value .numpy ().tolist (),
140
177
)
141
178
if e0 [0 ] != b1 [0 ]:
142
- return False
179
+ return check_result . fail ( "End0 is not equal to Begin1." )
143
180
shape = x .shape
144
181
if shape is None :
145
- return False
182
+ return check_result . fail ( "Shape is not known." )
146
183
last_dim = shape [- 1 ]
147
184
if not isinstance (last_dim , int ):
148
- return False
185
+ return check_result . fail ( "Last dimension is not known." )
149
186
if last_dim != e1 [0 ]:
150
- return False
187
+ return check_result . fail ( "Last dimension is not equal to End1." )
151
188
if last_dim // 2 != b1 [0 ]:
152
- return False
153
- return True
189
+ return check_result . fail ( "Last dimension is not equal to Begin1." )
190
+ return check_result
154
191
155
192
@classmethod
156
193
def rewrite (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
@@ -167,13 +204,14 @@ def pattern(cls, op, x, perm):
167
204
return op .Transpose (x , perm = perm )
168
205
169
206
@classmethod
170
- def check (cls , context , x : ir .Value , perm : ir .Attr ) -> bool :
207
+ def check (cls , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
208
+ check_result = orp .MatchResult ()
171
209
if isinstance (perm , ir .RefAttr ):
172
- return False
210
+ return check_result . fail ( "Permutation is a reference attribute." )
173
211
if perm .type == ir .AttributeType .INTS :
174
212
if perm .value == list (range (len (perm .value ))):
175
- return True
176
- return False
213
+ return check_result
214
+ return check_result . fail ( "Permutation is not identity." )
177
215
178
216
@classmethod
179
217
def rewrite (cls , op , x : ir .Value , perm : ir .Attr ):
@@ -190,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2):
190
228
return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
191
229
192
230
@classmethod
193
- def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> bool :
231
+ def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
232
+ check_result = orp .MatchResult ()
194
233
if isinstance (perm1 , ir .RefAttr ) or isinstance (perm2 , ir .RefAttr ):
195
- return False
196
- return True
234
+ return check_result . fail ( "Permutation is a reference attribute." )
235
+ return check_result
197
236
198
237
@classmethod
199
238
def _apply_transpose (cls , perm : tuple [int , ...], on : list [int ]) -> list [int ]:
@@ -237,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
237
276
return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
238
277
239
278
@classmethod
240
- def check (cls , context , x , axes1 , axes2 ) -> bool :
279
+ def check (cls , context , x , axes1 , axes2 ) -> orp .MatchResult :
280
+ check_result = orp .MatchResult ()
241
281
del context # Unused
242
282
del x # Unused
243
283
# Currently restricted to single element positive axis
244
284
v1 = ir_utils .get_singleton_value (axes1 )
245
285
v2 = ir_utils .get_singleton_value (axes2 )
246
286
if v1 is None or v2 is None :
247
- return False
287
+ return check_result . fail ( "Axes are not constant." )
248
288
if (v1 < 0 ) or (v2 < 0 ):
249
- return False
250
- return True
289
+ return check_result . fail ( "Axes are negative." )
290
+ return check_result
251
291
252
292
253
293
cast_cast_rule = orp .make_rewrite_rule_from_class (CastCast )
0 commit comments