diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index dd8c2aedaf..f721bf5c9e 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -26,9 +26,12 @@ def pattern(self, op, x): def rewrite(self, op, x: ir.Value): return op.Identity(x) - def check(self, context, x) -> bool: + def check(self, context, x) -> orp.MatchResult: del context # Unused - return ir_utils.has_rank(x, 1) + check_result = orp.MatchResult() + if not ir_utils.has_rank(x, 1): + return check_result.fail("Input is not 1D") + return check_result class CastIdentity(orp.RewriteRuleAsClass): @@ -43,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr): return op.Identity(x) @classmethod - def check(cls, context, x, to) -> bool: - return x.dtype == to.value + def check(cls, context, x, to) -> orp.MatchResult: + check_result = orp.MatchResult() + if x.dtype != to.value: + return check_result.fail("Input and output types are not the same") + return check_result class CastCast(orp.RewriteRuleAsClass): @@ -62,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored): return op.Cast(op.Cast(x, to=to_ignored), to=to) @classmethod - def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> bool: - return ( - to.value in cls._allowed_tensor_types - and to_ignored.value in cls._allowed_tensor_types - ) + def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() + if to.value not in cls._allowed_tensor_types: + return check_result.fail(f"Output type {to.value} is not allowed") + if to_ignored.value not in cls._allowed_tensor_types: + return check_result.fail(f"Ignored type {to_ignored.value} is not allowed") + return check_result @classmethod def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): @@ -85,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value): return op.Identity(x) @classmethod - def check(cls, context, x, 
shape) -> bool: + def check(cls, context, x, shape) -> orp.MatchResult: + check_result = orp.MatchResult() if shape.const_value is None: # Shape is not a constant and cannot be guessed. - return False + return check_result.fail("Shape is not a constant and cannot be guessed.") if (x_shape := x.shape) is None: # We don't know the shape of the input - return False - return x_shape.dims == tuple(shape.const_value.numpy().tolist()) + return check_result.fail("Input shape is not known.") + if x_shape.dims != tuple(shape.const_value.numpy().tolist()): + return check_result.fail( + f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}." + ) + return check_result class ReshapeReshape(orp.RewriteRuleAsClass): @@ -110,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): return op.Reshape(x, shape) @classmethod - def check(cls, context, x, shape_ignored, shape) -> bool: - if shape_ignored.const_value is None or shape.const_value is None: - return False + def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult: + check_result = orp.MatchResult() + if shape_ignored.const_value is None: + return check_result.fail("Shape ignored is not a constant.") + if shape.const_value is None: + return check_result.fail("Shape is not a constant.") if shape.const_value.numpy().min() <= 0: - return False - return True + return check_result.fail("Shape has non-positive values.") + return check_result class SlicesSplit(orp.RewriteRuleAsClass): @@ -128,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1) @classmethod - def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool: + def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult: + check_result = orp.MatchResult() if ( axes0.const_value is None or axes1.const_value is None or 
axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist() ): - return False + return check_result.fail("Axes are not equal or not constant.") axes = axes0.const_value.numpy().tolist() if len(axes) != 1: - return False + return check_result.fail("Axes does not have exactly one element.") if x.shape: rk = len(x.shape) else: rk = x.rank if axes[0] != -1 and axes[0] != rk - 1: - return False + return check_result.fail("Axes is not -1 or last dimension.") if ( begin0.const_value is None or end0.const_value is None or begin1.const_value is None or end1.const_value is None ): - return False + return check_result.fail("Begin or end are not constant values.") if begin0.const_value.numpy().tolist() != [0]: - return False + return check_result.fail("First begin value is not 0.") e0, b1, e1 = ( end0.const_value.numpy().tolist(), begin1.const_value.numpy().tolist(), end1.const_value.numpy().tolist(), ) if e0[0] != b1[0]: - return False + return check_result.fail("End0 is not equal to Begin1.") shape = x.shape if shape is None: - return False + return check_result.fail("Shape is not known.") last_dim = shape[-1] if not isinstance(last_dim, int): - return False + return check_result.fail("Last dimension is not known.") if last_dim != e1[0]: - return False + return check_result.fail("Last dimension is not equal to End1.") if last_dim // 2 != b1[0]: - return False - return True + return check_result.fail("Half of last dimension is not equal to Begin1.") + return check_result @classmethod def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): @@ -187,13 +204,14 @@ def pattern(cls, op, x, perm): return op.Transpose(x, perm=perm) @classmethod - def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: + def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() if isinstance(perm, ir.RefAttr): - return False + return check_result.fail("Permutation is a reference attribute.") if perm.type == ir.AttributeType.INTS: if 
perm.value == list(range(len(perm.value))): - return True - return False + return check_result + return check_result.fail("Permutation is not identity.") @classmethod def rewrite(cls, op, x: ir.Value, perm: ir.Attr): @@ -210,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2): return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2) @classmethod - def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> bool: + def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr): - return False - return True + return check_result.fail("Permutation is a reference attribute.") + return check_result @classmethod def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]: @@ -257,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value): return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64))) @classmethod - def check(cls, context, x, axes1, axes2) -> bool: + def check(cls, context, x, axes1, axes2) -> orp.MatchResult: + check_result = orp.MatchResult() del context # Unused del x # Unused # Currently restricted to single element positive axis v1 = ir_utils.get_singleton_value(axes1) v2 = ir_utils.get_singleton_value(axes2) if v1 is None or v2 is None: - return False + return check_result.fail("Axes are not constant.") if (v1 < 0) or (v2 < 0): - return False - return True + return check_result.fail("Axes are negative.") + return check_result cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 476226c6a2..cf0522c5ad 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -96,10 +96,16 @@ def pattern( _domain="ai.onnxruntime.fusion", ) - def check(self, context, inv_freq, 
position_ids, freqs, extra_dims, **_): + def check( + self, context, inv_freq, position_ids, freqs, extra_dims, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() # TODO(rama): handle redundant reshape/expand if self._const_freqs: - return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3) + if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3): + return check_result.fail("freqs is not a constant or not 3D.", freqs) + else: + return check_result if ( _ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1) ) or ( @@ -107,13 +113,15 @@ def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_): ): pass else: - return False + return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids) if not _ir_utils.has_rank(inv_freq, 3): - return False + return check_result.fail("inv_freq is not 3D.", inv_freq) inv_freq_shape = inv_freq.shape if inv_freq.const_value is None: # TODO: should this be inv_freq_shape? 
- return False - return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 + return check_result.fail("inv_freq is not a constant.", inv_freq) + if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1: + return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq) + return check_result def rewrite( self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_ diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 65496ec8bd..d60d8ad300 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -15,13 +15,14 @@ def pattern(cls, op, x, y, cst): return op.Div(op.MatMul(x, y), cst) @classmethod - def check(cls, context, x, y, cst) -> bool: + def check(cls, context, x, y, cst) -> orp.MatchResult: + check_result = orp.MatchResult() if cst.const_value is None: - return False + return check_result.fail("Divisor is not a constant value.") value = cst.const_value.numpy() if value.size > 1: - return False - return True + return check_result.fail("Divisor is not a scalar value.") + return check_result @classmethod def rewrite(cls, op, x, y, cst): @@ -38,12 +39,13 @@ def pattern(cls, op, x, y, cst): return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst) @classmethod - def check(cls, context, x, y, cst) -> bool: + def check(cls, context, x, y, cst) -> orp.MatchResult: + check_result = orp.MatchResult() if cst.const_value is None: - return False + return check_result.fail("Divisor is not a constant value.") if cst.const_value.numpy().size > 1: - return False - return True + return check_result.fail("Divisor is not a scalar value.") + return check_result @classmethod def rewrite(cls, op, x, y, cst): @@ -65,11 +67,14 @@ class _TransposeMatMulBase(orp.RewriteRuleAsClass): _pos: ClassVar = 1 @classmethod - def check(cls, context, x, y) -> bool: + def check(cls, context, x, y) -> 
orp.MatchResult: + check_result = orp.MatchResult() perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015 expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - return perm == expected_perm + if perm != expected_perm: + return check_result.fail("Permutation values for Transpose are not correct.") + return check_result @classmethod def rewrite(cls, op, x, y): @@ -126,13 +131,16 @@ def pattern(cls, op, x, y): return op.Transpose(op.MatMul(x, y)) @classmethod - def check(cls, context, x, y) -> bool: + def check(cls, context, x, y) -> orp.MatchResult: + check_result = orp.MatchResult() matmul = list(x.uses())[0][0] # noqa: RUF015 transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 perm = transpose.attributes["perm"].value expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - return perm == expected_perm + if perm != expected_perm: + return check_result.fail("Permutation values for Transpose are not correct.") + return check_result @classmethod def rewrite(cls, op, x, y): diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4bad28c789..b57519ad17 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -94,7 +94,8 @@ def check( # key_transposed, # attention_reshaped, **_, - ): + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() # bindings: dict[str, int] = {} # status = ( # _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) @@ -110,7 +111,7 @@ def check( # return False # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: # return False - return True + return check_result def rewrite( self, diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 0563dc4edd..8bb85f2aed 100644 --- 
a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -163,29 +163,54 @@ def check( key_BSHDh, value_BSHDh, **_, - ): + ) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: return not _check_shape(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']", + query_BSD, + ) if no_match(key_BSD, ["B", "Skv", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']", + key_BSD, + ) if no_match(value_BSD, ["B", "Skv", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']", + value_BSD, + ) if no_match(past_key, ["B", "H", "Spast", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']", + past_key, + ) if no_match(past_value, ["B", "H", "Spast", "Dv"]): - return False + return check_result.fail( + f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']", + past_value, + ) if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + query_BSHDh, + ) if no_match(key_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + key_BSHDh, + ) if no_match(value_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + value_BSHDh, + ) # TODO: mask shape check: ideally, 
it should be (1 or B, 1 or H, S, St) # But this also, unforunately, depends on ORT version. @@ -193,7 +218,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # or check Reshape's shape-input value - return True + return check_result def rewrite( self, diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index 4cea9d7b90..55b7f190b2 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -52,21 +52,22 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = op.Cast(normalized, to=target_dtype) return op.Mul(scale, normalized) - def check(self, op, x, scale, epsilon, compute_dtype, target_dtype): + def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() # epsilon must be a scalar epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): # TODO: support other types - return False + return check_result.fail("Epsilon is not a float value.", epsilon) # input and output must be same dtype if x.dtype not in float_types: - return False + return check_result.fail("Input is not a float type.", x) if scale.dtype not in float_types: - return False + return check_result.fail("Scale is not a float type.", scale) stash_dtype = compute_dtype.value if self._cast_input else x.dtype if stash_dtype not in fp_float_types: - return False - return True + return check_result.fail("Normalization precision is not a float or double type.") + return check_result def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): stash_dtype = compute_dtype.value if 
self._cast_input else x.dtype diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index 8eb7c26f9b..5bb34cf5bf 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -30,24 +30,29 @@ def __init__(self): def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin - def check(self, op, x, start1, end1, start2, end2, **_): + def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: - return False + return check_result.fail("Input is not a 4D tensor.", x) if not isinstance(x.shape[1], int): - return False + return check_result.fail("Input dimension 1 is not an integer.", x) head_size = x.shape[3] if not isinstance(head_size, int): - return False + return check_result.fail("Head size is not an integer.", x) half_head_size = head_size // 2 # Check that x is being split into two equal halves of size half_head_size - return ( + if not ( _ir_utils.is_singleton_value(start1, 0) and _ir_utils.is_singleton_value(end1, half_head_size) and _ir_utils.is_singleton_value(start2, half_head_size) and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) - ) + ): + return check_result.fail( + "x is not being split into two equal halves of size half_head_size." 
+ ) + return check_result def rewrite(self, op, x, cos, sin, **_): num_heads = x.shape[1] @@ -69,22 +74,27 @@ def pattern(self, op, x, end1, start2): ) return op.Concat(x_part_1_rope, x_part_2, axis=-1) - def check(self, op, x, end1, start2, x_part_1_rope, **_): + def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: # type: ignore[name-defined] + check_result = pattern.MatchResult() end1_value = _ir_utils.get_singleton_value(end1) start2_value = _ir_utils.get_singleton_value(start2) if not isinstance(end1_value, int) or not isinstance(start2_value, int): - return False + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not integers." + ) if end1_value != start2_value: - return False + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not equal." + ) rotary_embedding_attributes = x_part_1_rope.producer().attributes if "rotary_embedding_dim" in rotary_embedding_attributes: - return False + return check_result.fail("rotary_embedding_dim attribute already specified.") if ( "interleaved" in rotary_embedding_attributes and rotary_embedding_attributes["interleaved"].value != 0 ): - return False - return True + return check_result.fail("interleaved is not equal to 0.") + return check_result def rewrite(self, op, x, end1, x_part_1_rope, **_): # Create a modified version of the RotaryEmbedding op: diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 8eefc9aec0..a277f7199f 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -41,14 +41,17 @@ def pattern( return attn_output def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): + check_result = pattern.MatchResult() # Check that the scaling factors match what SDPA implements: # We need to know the hidden size to check the scaling factors. 
if query is None or query.shape is None or len(query.shape) < 2: - return False + return check_result.fail( + "Query shape is not known or has less than 2 dimensions.", query + ) hidden_size = query.shape[-1] if not isinstance(hidden_size, int): - return False + return check_result.fail("Hidden size is not an integer.") expected_scaling_factor = math.sqrt(hidden_size) if self._use_mul: expected_scaling_factor = 1.0 / expected_scaling_factor @@ -57,17 +60,26 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor) sqrt_scaling_factor = math.sqrt(expected_scaling_factor) if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "Query scale is not a scalar or does not match the expected scaling factor.", + query_scale, + ) if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "Key scale is not a scalar or does not match the expected scaling factor.", + key_scale, + ) else: # Check if qk_scale is a scalar == expected_scaling_factor) if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "QK scale is not a scalar or does not match the expected scaling factor.", + qk_scale, + ) # check ranks/shapes - return True + return check_result def rewrite(self, op, query, key_transposed, value, mask, **_): if self._use_mask: @@ -118,6 +130,10 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): ) -def fuse_sdpa(model: ir.Model) -> int: +def fuse_sdpa(model: ir.Model, debug: bool = False) -> int: count = sdpa_rules.apply_to_model(model) + if debug and count == 0: + tracer = pattern.MatchingTracer() + sdpa_rules.apply_to_model(model, tracer=tracer) + tracer.report() return count diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py 
b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 229c76aab6..1cd79e1c42 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -171,7 +171,7 @@ def test_sdpa_fusion(self, name, script_func): # inputs = test_case.get_ort_inputs() # original_outputs = ort_run("original", model, inputs) - count = fuse_sdpa(model) + count = fuse_sdpa(model, debug=True) self.assertGreater(count, 0) # Check that the fusion was successful diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6f7e1ea116..793675b4ab 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -330,17 +330,24 @@ def __init__(self) -> None: self.outputs: list[ir.Value] = [] # For a failed match, _reason is a string that describes the reason for the failure. self._reason: str = "" - # Track the node that caused the failure. - # TODO: May be useful to extend this to be a collection of Nodes and Values. - self._failure_node: ir.Node | None = None + # Track the node(s) or value(s) that caused the failure. 
+ self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = [] def __bool__(self): return self._success - def fail(self, reason: str = "", node: ir.Node | None = None) -> MatchResult: + def fail( + self, + reason: str = "", + failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None, + ) -> MatchResult: self._success = False self._reason = reason - self._failure_node = node + if failure_source is not None: + if isinstance(failure_source, list): + self._failure_nodes_and_values.extend(failure_source) + else: + self._failure_nodes_and_values.append(failure_source) return self @property @@ -1371,7 +1378,14 @@ def try_rewrite( if var.name is not None: if var.name not in match.bindings: match.bindings[var.name] = None - if not self._condition_function(context, **match.bindings): + check_match_result = self._condition_function(context, **match.bindings) + if not check_match_result: + # If check function was provided, but it failed, return the reason for failure to the tracer. + if isinstance(check_match_result, MatchResult): + match.fail( + check_match_result.reason, + check_match_result._failure_nodes_and_values, + ) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED @@ -1449,8 +1463,8 @@ def rewrite(cls, op, *_) -> Any: raise NotImplementedError("Method 'rewrite' must be overwritten.") @classmethod - def check(cls, context, *_, **__) -> bool: - return True + def check(cls, context, *_, **__) -> bool | MatchResult: + return MatchResult() def make_rewrite_rule_from_class( @@ -1532,8 +1546,9 @@ def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") def check(self, op, *args, **kwargs): - # Default check function that always returns True. - return True + # Default check function that returns a + # MatchResult object with success always set to True. 
+ return MatchResult() def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") @@ -1826,13 +1841,17 @@ def print(self): if self.status != MatchStatus.SUCCESS: reason = self.match_result.reason if reason: - print(f"Graph matching failed: {reason}") + if self.status == MatchStatus.CONDITION_FAILED: + print(f"Graph matching failed due to failing check condition : {reason}") + else: + print(f"Graph matching failed: {reason}") else: print("Graph matching failed.") - failure_node = self.match_result._failure_node - if failure_node: - print("Failure at or around node:") - failure_node.display() + failure_nodes_and_values = self.match_result._failure_nodes_and_values + if failure_nodes_and_values: + print("Failure at or around nodes/values:") + for failure_cause in failure_nodes_and_values: + failure_cause.display() print("Matched nodes:") import onnxscript.rewriter._ir_utils as ir_utils