Skip to content

Make the return type of rewrite check functions a MatchResult object #2138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 61 additions & 41 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@
def rewrite(self, op, x: ir.Value):
return op.Identity(x)

def check(self, context, x) -> bool:
def check(self, context, x) -> orp.MatchResult:
del context # Unused
return ir_utils.has_rank(x, 1)
check_result = orp.MatchResult()
if not ir_utils.has_rank(x, 1):
return check_result.fail("Input is not 1D")
return check_result


class CastIdentity(orp.RewriteRuleAsClass):
Expand All @@ -43,8 +46,11 @@
return op.Identity(x)

@classmethod
def check(cls, context, x, to) -> bool:
return x.dtype == to.value
def check(cls, context, x, to) -> orp.MatchResult:
check_result = orp.MatchResult()
if x.dtype != to.value:
return check_result.fail("Input and output types are not the same")
return check_result


class CastCast(orp.RewriteRuleAsClass):
Expand All @@ -62,11 +68,13 @@
return op.Cast(op.Cast(x, to=to_ignored), to=to)

@classmethod
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> bool:
return (
to.value in cls._allowed_tensor_types
and to_ignored.value in cls._allowed_tensor_types
)
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if to.value not in cls._allowed_tensor_types:
return check_result.fail(f"Output type {to.value} is not allowed")

Check warning on line 74 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L74

Added line #L74 was not covered by tests
if to_ignored.value not in cls._allowed_tensor_types:
return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")

Check warning on line 76 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L76

Added line #L76 was not covered by tests
return check_result

@classmethod
def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
Expand All @@ -85,14 +93,19 @@
return op.Identity(x)

@classmethod
def check(cls, context, x, shape) -> bool:
def check(cls, context, x, shape) -> orp.MatchResult:
check_result = orp.MatchResult()
if shape.const_value is None:
# Shape is not a constant and cannot be guessed.
return False
return check_result.fail("Shape is not a constant and cannot be guessed.")

Check warning on line 100 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L100

Added line #L100 was not covered by tests
if (x_shape := x.shape) is None:
# We don't know the shape of the input
return False
return x_shape.dims == tuple(shape.const_value.numpy().tolist())
return check_result.fail("Input shape is not known.")
if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
return check_result.fail(
f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
)
return check_result


class ReshapeReshape(orp.RewriteRuleAsClass):
Expand All @@ -110,12 +123,15 @@
return op.Reshape(x, shape)

@classmethod
def check(cls, context, x, shape_ignored, shape) -> bool:
if shape_ignored.const_value is None or shape.const_value is None:
return False
def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
check_result = orp.MatchResult()
if shape_ignored.const_value is None:
return check_result.fail("Shape ignored is not a constant.")

Check warning on line 129 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L129

Added line #L129 was not covered by tests
if shape.const_value is None:
return check_result.fail("Shape is not a constant.")

Check warning on line 131 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L131

Added line #L131 was not covered by tests
if shape.const_value.numpy().min() <= 0:
return False
return True
return check_result.fail("Shape has non-positive values.")

Check warning on line 133 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L133

Added line #L133 was not covered by tests
return check_result


class SlicesSplit(orp.RewriteRuleAsClass):
Expand All @@ -128,49 +144,50 @@
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)

@classmethod
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool:
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
check_result = orp.MatchResult()

Check warning on line 148 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L148

Added line #L148 was not covered by tests
if (
axes0.const_value is None
or axes1.const_value is None
or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
):
return False
return check_result.fail("Axes are not equal or not constant.")

Check warning on line 154 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L154

Added line #L154 was not covered by tests
axes = axes0.const_value.numpy().tolist()
if len(axes) != 1:
return False
return check_result.fail("Axes has more than one dimension.")

Check warning on line 157 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L157

Added line #L157 was not covered by tests
if x.shape:
rk = len(x.shape)
else:
rk = x.rank
if axes[0] != -1 and axes[0] != rk - 1:
return False
return check_result.fail("Axes is not -1 or last dimension.")

Check warning on line 163 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L163

Added line #L163 was not covered by tests
if (
begin0.const_value is None
or end0.const_value is None
or begin1.const_value is None
or end1.const_value is None
):
return False
return check_result.fail("Begin or end are not constant values.")

Check warning on line 170 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L170

Added line #L170 was not covered by tests
if begin0.const_value.numpy().tolist() != [0]:
return False
return check_result.fail("First begin value is not 0.")

Check warning on line 172 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L172

Added line #L172 was not covered by tests
e0, b1, e1 = (
end0.const_value.numpy().tolist(),
begin1.const_value.numpy().tolist(),
end1.const_value.numpy().tolist(),
)
if e0[0] != b1[0]:
return False
return check_result.fail("End0 is not equal to Begin1.")

Check warning on line 179 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L179

Added line #L179 was not covered by tests
shape = x.shape
if shape is None:
return False
return check_result.fail("Shape is not known.")

Check warning on line 182 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L182

Added line #L182 was not covered by tests
last_dim = shape[-1]
if not isinstance(last_dim, int):
return False
return check_result.fail("Last dimension is not known.")

Check warning on line 185 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L185

Added line #L185 was not covered by tests
if last_dim != e1[0]:
return False
return check_result.fail("Last dimension is not equal to End1.")

Check warning on line 187 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L187

Added line #L187 was not covered by tests
if last_dim // 2 != b1[0]:
return False
return True
return check_result.fail("Last dimension is not equal to Begin1.")
return check_result

Check warning on line 190 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L189-L190

Added lines #L189 - L190 were not covered by tests

@classmethod
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
Expand All @@ -187,13 +204,14 @@
return op.Transpose(x, perm=perm)

@classmethod
def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if isinstance(perm, ir.RefAttr):
return False
return check_result.fail("Permutation is a reference attribute.")

Check warning on line 210 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L210

Added line #L210 was not covered by tests
if perm.type == ir.AttributeType.INTS:
if perm.value == list(range(len(perm.value))):
return True
return False
return check_result
return check_result.fail("Permutation is not identity.")

@classmethod
def rewrite(cls, op, x: ir.Value, perm: ir.Attr):
Expand All @@ -210,10 +228,11 @@
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)

@classmethod
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> bool:
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
check_result = orp.MatchResult()
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
return False
return True
return check_result.fail("Permutation is a reference attribute.")

Check warning on line 234 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L234

Added line #L234 was not covered by tests
return check_result

@classmethod
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
Expand Down Expand Up @@ -257,17 +276,18 @@
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))

@classmethod
def check(cls, context, x, axes1, axes2) -> bool:
def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
check_result = orp.MatchResult()
del context # Unused
del x # Unused
# Currently restricted to single element positive axis
v1 = ir_utils.get_singleton_value(axes1)
v2 = ir_utils.get_singleton_value(axes2)
if v1 is None or v2 is None:
return False
return check_result.fail("Axes are not constant.")

Check warning on line 287 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L287

Added line #L287 was not covered by tests
if (v1 < 0) or (v2 < 0):
return False
return True
return check_result.fail("Axes are negative.")

Check warning on line 289 in onnxscript/rewriter/llama_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/llama_rule_sets.py#L289

Added line #L289 was not covered by tests
return check_result


cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)
Expand Down
20 changes: 14 additions & 6 deletions onnxscript/rewriter/ort_fusions/cos_sin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,32 @@
_domain="ai.onnxruntime.fusion",
)

def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_):
def check(
self, context, inv_freq, position_ids, freqs, extra_dims, **_
) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()
# TODO(rama): handle redundant reshape/expand
if self._const_freqs:
return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3)
if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3):
return check_result.fail("freqs is not a constant or not 3D.", freqs)
else:
return check_result

Check warning on line 108 in onnxscript/rewriter/ort_fusions/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L108

Added line #L108 was not covered by tests
if (
_ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1)
) or (
_ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1])
):
pass
else:
return False
return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids)

Check warning on line 116 in onnxscript/rewriter/ort_fusions/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L116

Added line #L116 was not covered by tests
if not _ir_utils.has_rank(inv_freq, 3):
return False
return check_result.fail("inv_freq is not 3D.", inv_freq)

Check warning on line 118 in onnxscript/rewriter/ort_fusions/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L118

Added line #L118 was not covered by tests
inv_freq_shape = inv_freq.shape
if inv_freq.const_value is None: # TODO: should this be inv_freq_shape?
return False
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1
return check_result.fail("inv_freq is not a constant.", inv_freq)

Check warning on line 121 in onnxscript/rewriter/ort_fusions/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L121

Added line #L121 was not covered by tests
if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq)

Check warning on line 123 in onnxscript/rewriter/ort_fusions/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/cos_sin_cache.py#L123

Added line #L123 was not covered by tests
return check_result

def rewrite(
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_
Expand Down
32 changes: 20 additions & 12 deletions onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
return op.Div(op.MatMul(x, y), cst)

@classmethod
def check(cls, context, x, y, cst) -> bool:
def check(cls, context, x, y, cst) -> orp.MatchResult:
check_result = orp.MatchResult()
if cst.const_value is None:
return False
return check_result.fail("Divisor is not a constant value.")

Check warning on line 21 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L21

Added line #L21 was not covered by tests
value = cst.const_value.numpy()
if value.size > 1:
return False
return True
return check_result.fail("Divisor is not a scalar value.")

Check warning on line 24 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L24

Added line #L24 was not covered by tests
return check_result

@classmethod
def rewrite(cls, op, x, y, cst):
Expand All @@ -38,12 +39,13 @@
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)

@classmethod
def check(cls, context, x, y, cst) -> bool:
def check(cls, context, x, y, cst) -> orp.MatchResult:
check_result = orp.MatchResult()
if cst.const_value is None:
return False
return check_result.fail("Divisor is not a constant value.")

Check warning on line 45 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L45

Added line #L45 was not covered by tests
if cst.const_value.numpy().size > 1:
return False
return True
return check_result.fail("Divisor is not a scalar value.")

Check warning on line 47 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L47

Added line #L47 was not covered by tests
return check_result

@classmethod
def rewrite(cls, op, x, y, cst):
Expand All @@ -65,11 +67,14 @@
_pos: ClassVar = 1

@classmethod
def check(cls, context, x, y) -> bool:
def check(cls, context, x, y) -> orp.MatchResult:
check_result = orp.MatchResult()
perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
return perm == expected_perm
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")

Check warning on line 76 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L76

Added line #L76 was not covered by tests
return check_result

@classmethod
def rewrite(cls, op, x, y):
Expand Down Expand Up @@ -126,13 +131,16 @@
return op.Transpose(op.MatMul(x, y))

@classmethod
def check(cls, context, x, y) -> bool:
def check(cls, context, x, y) -> orp.MatchResult:
check_result = orp.MatchResult()
matmul = list(x.uses())[0][0] # noqa: RUF015
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
perm = transpose.attributes["perm"].value
expected_perm = list(range(len(perm)))
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
return perm == expected_perm
if perm != expected_perm:
return check_result.fail("Permutation values for Transpose are not correct.")

Check warning on line 142 in onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py#L142

Added line #L142 was not covered by tests
return check_result

@classmethod
def rewrite(cls, op, x, y):
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/rewriter/ort_fusions/gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
# key_transposed,
# attention_reshaped,
**_,
):
) -> pattern.MatchResult: # type: ignore[name-defined]
check_result = pattern.MatchResult()

Check warning on line 98 in onnxscript/rewriter/ort_fusions/gqa.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/gqa.py#L98

Added line #L98 was not covered by tests
# bindings: dict[str, int] = {}
# status = (
# _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
Expand All @@ -110,7 +111,7 @@
# return False
# if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
# return False
return True
return check_result

Check warning on line 114 in onnxscript/rewriter/ort_fusions/gqa.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/ort_fusions/gqa.py#L114

Added line #L114 was not covered by tests

def rewrite(
self,
Expand Down
Loading
Loading