Skip to content

Commit 88c10a3

Browse files
shubhambhokare1bmehta001
authored andcommitted
Make the return type of rewrite check functions a MatchResult object (microsoft#2138)
- Check function returns a MatchResult object instead of bool - This allows propagating the failure reason to the tracer to help in debugging
1 parent dbb5f02 commit 88c10a3

File tree

10 files changed

+320
-108
lines changed

10 files changed

+320
-108
lines changed

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,29 @@
1111
from onnxscript.rewriter import pattern as orp
1212

1313

14+
class SqueezeReshape(orp.RewriteRuleClassBase):
15+
"""Replaces ``Reshape(Squeeze(x), [-1]])`` with ``Identity(x)`` for 1D x.
16+
17+
This pattern arises from the translation of pytorch symints.
18+
"""
19+
20+
def __init__(self):
21+
super().__init__("SqueezeReshape1d", remove_nodes=False)
22+
23+
def pattern(self, op, x):
24+
return op.Reshape(op.Squeeze(x), [-1])
25+
26+
def rewrite(self, op, x: ir.Value):
27+
return op.Identity(x)
28+
29+
def check(self, context, x) -> orp.MatchResult:
30+
del context # Unused
31+
check_result = orp.MatchResult()
32+
if not ir_utils.has_rank(x, 1):
33+
return check_result.fail("Input is not 1D")
34+
return check_result
35+
36+
1437
class CastIdentity(orp.RewriteRuleAsClass):
1538
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
1639

@@ -23,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr):
2346
return op.Identity(x)
2447

2548
@classmethod
26-
def check(cls, context, x, to) -> bool:
27-
return x.dtype == to.value
49+
def check(cls, context, x, to) -> orp.MatchResult:
50+
check_result = orp.MatchResult()
51+
if x.dtype != to.value:
52+
return check_result.fail("Input and output types are not the same")
53+
return check_result
2854

2955

3056
class CastCast(orp.RewriteRuleAsClass):
@@ -42,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored):
4268
return op.Cast(op.Cast(x, to=to_ignored), to=to)
4369

4470
@classmethod
45-
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> bool:
46-
return (
47-
to.value in cls._allowed_tensor_types
48-
and to_ignored.value in cls._allowed_tensor_types
49-
)
71+
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
72+
check_result = orp.MatchResult()
73+
if to.value not in cls._allowed_tensor_types:
74+
return check_result.fail(f"Output type {to.value} is not allowed")
75+
if to_ignored.value not in cls._allowed_tensor_types:
76+
return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")
77+
return check_result
5078

5179
@classmethod
5280
def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
@@ -65,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value):
6593
return op.Identity(x)
6694

6795
@classmethod
68-
def check(cls, context, x, shape) -> bool:
96+
def check(cls, context, x, shape) -> orp.MatchResult:
97+
check_result = orp.MatchResult()
6998
if shape.const_value is None:
7099
# Shape is not a constant and cannot be guessed.
71-
return False
100+
return check_result.fail("Shape is not a constant and cannot be guessed.")
72101
if (x_shape := x.shape) is None:
73102
# We don't know the shape of the input
74-
return False
75-
return x_shape.dims == tuple(shape.const_value.numpy().tolist())
103+
return check_result.fail("Input shape is not known.")
104+
if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
105+
return check_result.fail(
106+
f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
107+
)
108+
return check_result
76109

77110

78111
class ReshapeReshape(orp.RewriteRuleAsClass):
@@ -90,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
90123
return op.Reshape(x, shape)
91124

92125
@classmethod
93-
def check(cls, context, x, shape_ignored, shape) -> bool:
94-
if shape_ignored.const_value is None or shape.const_value is None:
95-
return False
126+
def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
127+
check_result = orp.MatchResult()
128+
if shape_ignored.const_value is None:
129+
return check_result.fail("Shape ignored is not a constant.")
130+
if shape.const_value is None:
131+
return check_result.fail("Shape is not a constant.")
96132
if shape.const_value.numpy().min() <= 0:
97-
return False
98-
return True
133+
return check_result.fail("Shape has non-positive values.")
134+
return check_result
99135

100136

101137
class SlicesSplit(orp.RewriteRuleAsClass):
@@ -108,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
108144
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)
109145

110146
@classmethod
111-
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool:
147+
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
148+
check_result = orp.MatchResult()
112149
if (
113150
axes0.const_value is None
114151
or axes1.const_value is None
115152
or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
116153
):
117-
return False
154+
return check_result.fail("Axes are not equal or not constant.")
118155
axes = axes0.const_value.numpy().tolist()
119156
if len(axes) != 1:
120-
return False
157+
return check_result.fail("Axes has more than one dimension.")
121158
if x.shape:
122159
rk = len(x.shape)
123160
else:
124161
rk = x.rank
125162
if axes[0] != -1 and axes[0] != rk - 1:
126-
return False
163+
return check_result.fail("Axes is not -1 or last dimension.")
127164
if (
128165
begin0.const_value is None
129166
or end0.const_value is None
130167
or begin1.const_value is None
131168
or end1.const_value is None
132169
):
133-
return False
170+
return check_result.fail("Begin or end are not constant values.")
134171
if begin0.const_value.numpy().tolist() != [0]:
135-
return False
172+
return check_result.fail("First begin value is not 0.")
136173
e0, b1, e1 = (
137174
end0.const_value.numpy().tolist(),
138175
begin1.const_value.numpy().tolist(),
139176
end1.const_value.numpy().tolist(),
140177
)
141178
if e0[0] != b1[0]:
142-
return False
179+
return check_result.fail("End0 is not equal to Begin1.")
143180
shape = x.shape
144181
if shape is None:
145-
return False
182+
return check_result.fail("Shape is not known.")
146183
last_dim = shape[-1]
147184
if not isinstance(last_dim, int):
148-
return False
185+
return check_result.fail("Last dimension is not known.")
149186
if last_dim != e1[0]:
150-
return False
187+
return check_result.fail("Last dimension is not equal to End1.")
151188
if last_dim // 2 != b1[0]:
152-
return False
153-
return True
189+
return check_result.fail("Last dimension is not equal to Begin1.")
190+
return check_result
154191

155192
@classmethod
156193
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
@@ -167,13 +204,14 @@ def pattern(cls, op, x, perm):
167204
return op.Transpose(x, perm=perm)
168205

169206
@classmethod
170-
def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
207+
def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
208+
check_result = orp.MatchResult()
171209
if isinstance(perm, ir.RefAttr):
172-
return False
210+
return check_result.fail("Permutation is a reference attribute.")
173211
if perm.type == ir.AttributeType.INTS:
174212
if perm.value == list(range(len(perm.value))):
175-
return True
176-
return False
213+
return check_result
214+
return check_result.fail("Permutation is not identity.")
177215

178216
@classmethod
179217
def rewrite(cls, op, x: ir.Value, perm: ir.Attr):
@@ -190,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2):
190228
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)
191229

192230
@classmethod
193-
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> bool:
231+
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
232+
check_result = orp.MatchResult()
194233
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
195-
return False
196-
return True
234+
return check_result.fail("Permutation is a reference attribute.")
235+
return check_result
197236

198237
@classmethod
199238
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
@@ -237,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
237276
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))
238277

239278
@classmethod
240-
def check(cls, context, x, axes1, axes2) -> bool:
279+
def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
280+
check_result = orp.MatchResult()
241281
del context # Unused
242282
del x # Unused
243283
# Currently restricted to single element positive axis
244284
v1 = ir_utils.get_singleton_value(axes1)
245285
v2 = ir_utils.get_singleton_value(axes2)
246286
if v1 is None or v2 is None:
247-
return False
287+
return check_result.fail("Axes are not constant.")
248288
if (v1 < 0) or (v2 < 0):
249-
return False
250-
return True
289+
return check_result.fail("Axes are negative.")
290+
return check_result
251291

252292

253293
cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)

onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ def pattern(cls, op, x, y, cst):
1515
return op.Div(op.MatMul(x, y), cst)
1616

1717
@classmethod
18-
def check(cls, context, x, y, cst) -> bool:
18+
def check(cls, context, x, y, cst) -> orp.MatchResult:
19+
check_result = orp.MatchResult()
1920
if cst.const_value is None:
20-
return False
21+
return check_result.fail("Divisor is not a constant value.")
2122
value = cst.const_value.numpy()
2223
if value.size > 1:
23-
return False
24-
return True
24+
return check_result.fail("Divisor is not a scalar value.")
25+
return check_result
2526

2627
@classmethod
2728
def rewrite(cls, op, x, y, cst):
@@ -38,12 +39,13 @@ def pattern(cls, op, x, y, cst):
3839
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
3940

4041
@classmethod
41-
def check(cls, context, x, y, cst) -> bool:
42+
def check(cls, context, x, y, cst) -> orp.MatchResult:
43+
check_result = orp.MatchResult()
4244
if cst.const_value is None:
43-
return False
45+
return check_result.fail("Divisor is not a constant value.")
4446
if cst.const_value.numpy().size > 1:
45-
return False
46-
return True
47+
return check_result.fail("Divisor is not a scalar value.")
48+
return check_result
4749

4850
@classmethod
4951
def rewrite(cls, op, x, y, cst):
@@ -65,11 +67,14 @@ class _TransposeMatMulBase(orp.RewriteRuleAsClass):
6567
_pos: ClassVar = 1
6668

6769
@classmethod
68-
def check(cls, context, x, y) -> bool:
70+
def check(cls, context, x, y) -> orp.MatchResult:
71+
check_result = orp.MatchResult()
6972
perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
7073
expected_perm = list(range(len(perm)))
7174
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
72-
return perm == expected_perm
75+
if perm != expected_perm:
76+
return check_result.fail("Permutation values for Transpose are not correct.")
77+
return check_result
7378

7479
@classmethod
7580
def rewrite(cls, op, x, y):
@@ -126,13 +131,16 @@ def pattern(cls, op, x, y):
126131
return op.Transpose(op.MatMul(x, y))
127132

128133
@classmethod
129-
def check(cls, context, x, y) -> bool:
134+
def check(cls, context, x, y) -> orp.MatchResult:
135+
check_result = orp.MatchResult()
130136
matmul = list(x.uses())[0][0] # noqa: RUF015
131137
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
132138
perm = transpose.attributes["perm"].value
133139
expected_perm = list(range(len(perm)))
134140
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
135-
return perm == expected_perm
141+
if perm != expected_perm:
142+
return check_result.fail("Permutation values for Transpose are not correct.")
143+
return check_result
136144

137145
@classmethod
138146
def rewrite(cls, op, x, y):

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,32 @@ def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs,
9292
_domain="ai.onnxruntime.fusion",
9393
)
9494

95-
def check(self, context, inv_freq, position_ids, freqs, **_):
95+
def check(
96+
self, context, inv_freq, position_ids, freqs, extra_dims, **_
97+
) -> pattern.MatchResult: # type: ignore[name-defined]
98+
check_result = pattern.MatchResult()
9699
# TODO(rama): handle redundant reshape/expand
97100
if self._const_freqs:
98-
return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3)
99-
if not _ir_utils.has_rank(position_ids, 2):
100-
return False
101+
if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3):
102+
return check_result.fail("freqs is not a constant or not 3D.", freqs)
103+
else:
104+
return check_result
105+
if (
106+
_ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1)
107+
) or (
108+
_ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1])
109+
):
110+
pass
111+
else:
112+
return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids)
101113
if not _ir_utils.has_rank(inv_freq, 3):
102-
return False
114+
return check_result.fail("inv_freq is not 3D.", inv_freq)
103115
inv_freq_shape = inv_freq.shape
104116
if inv_freq.const_value is None: # TODO: should this be inv_freq_shape?
105-
return False
106-
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1
117+
return check_result.fail("inv_freq is not a constant.", inv_freq)
118+
if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
119+
return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq)
120+
return check_result
107121

108122
def rewrite(
109123
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def check(
9494
# key_transposed,
9595
# attention_reshaped,
9696
**_,
97-
):
97+
) -> pattern.MatchResult: # type: ignore[name-defined]
98+
check_result = pattern.MatchResult()
9899
# bindings: dict[str, int] = {}
99100
# status = (
100101
# _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
@@ -110,7 +111,7 @@ def check(
110111
# return False
111112
# if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
112113
# return False
113-
return True
114+
return check_result
114115

115116
def rewrite(
116117
self,

0 commit comments

Comments
 (0)