From 118ce30a6175d0556d9abca33c33afb0effbd7c7 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 26 Mar 2025 18:37:54 +0000 Subject: [PATCH 1/8] add fail statements --- onnxscript/rewriter/generic_pattern_test.py | 4 +- onnxscript/rewriter/llama_rule_sets.py | 102 +++++++++++------- .../rewriter/ort_fusions/cos_sin_cache.py | 20 ++-- .../ort_fusions/fused_matmul_rule_sets.py | 32 +++--- onnxscript/rewriter/ort_fusions/gqa.py | 5 +- onnxscript/rewriter/ort_fusions/mha.py | 42 ++++++-- .../rewriter/ort_fusions/rms_normalization.py | 13 +-- .../rewriter/ort_fusions/rotary_embedding.py | 30 +++--- onnxscript/rewriter/ort_fusions/sdpa.py | 25 +++-- onnxscript/rewriter/ort_fusions/sdpa_test.py | 2 +- onnxscript/rewriter/pattern.py | 20 ++-- 11 files changed, 189 insertions(+), 106 deletions(-) diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index dadaf5e8bb..15e522d8dc 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -546,8 +546,8 @@ def transpose_transpose_mapping(perm0, perm1): # replace by return [perm0[p] for p in perm1] ? return new_perm - def transpose_transpose_check(op, **_) -> bool: - return True + def transpose_transpose_check(op, **_) -> pattern.MatchResult: + return pattern.MatchResult() def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): perm0 = XT.producer().attributes.get("perm") diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index dd8c2aedaf..81961d5883 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -26,9 +26,12 @@ def pattern(self, op, x): def rewrite(self, op, x: ir.Value): return op.Identity(x) - def check(self, context, x) -> bool: + def check(self, context, x) -> orp.MatchResult: del context # Unused - return ir_utils.has_rank(x, 1) + check_result = orp.MatchResult() + if not ir_utils.has_rank(x, 1): + return check_result.fail("Input is not 1D") + return check_result class CastIdentity(orp.RewriteRuleAsClass): @@ -43,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr): return op.Identity(x) @classmethod - def check(cls, context, x, to) -> bool: - return x.dtype == to.value + def check(cls, context, x, to) -> orp.MatchResult: + check_result = orp.MatchResult() + if x.dtype != to.value: + return check_result.fail("Input and output types are not the same") + return check_result class CastCast(orp.RewriteRuleAsClass): @@ -62,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored): return op.Cast(op.Cast(x, to=to_ignored), to=to) @classmethod - def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> bool: - return ( - to.value in cls._allowed_tensor_types - and to_ignored.value in cls._allowed_tensor_types - ) + def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() + if to.value not in cls._allowed_tensor_types: + return check_result.fail(f"Output type {to.value} is not allowed") + if to_ignored.value not in cls._allowed_tensor_types: + return check_result.fail(f"Ignored type {to_ignored.value} is not allowed") + return check_result @classmethod def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): @@ -85,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value): return op.Identity(x) @classmethod - def check(cls, context, x, shape) -> bool: + def check(cls, context, x, shape) -> orp.MatchResult: + check_result = orp.MatchResult() if shape.const_value is None: # Shape is not a constant and cannot be guessed. - return False + return check_result.fail("Shape is not a constant and cannot be guessed.") if (x_shape := x.shape) is None: # We don't know the shape of the input - return False - return x_shape.dims == tuple(shape.const_value.numpy().tolist()) + return check_result.fail("Input shape is not known.") + if x_shape.dims != tuple(shape.const_value.numpy().tolist()): + return check_result.fail( + f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}." + ) + return check_result class ReshapeReshape(orp.RewriteRuleAsClass): @@ -110,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): return op.Reshape(x, shape) @classmethod - def check(cls, context, x, shape_ignored, shape) -> bool: - if shape_ignored.const_value is None or shape.const_value is None: - return False + def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult: + check_result = orp.MatchResult() + if shape_ignored.const_value is None: + return check_result.fail("Shape ignored is not a constant.") + if shape.const_value is None: + return check_result.fail("Shape is not a constant.") if shape.const_value.numpy().min() <= 0: - return False - return True + return check_result.fail("Shape has non-positive values.") + return check_result class SlicesSplit(orp.RewriteRuleAsClass): @@ -128,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1) @classmethod - def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool: + def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp: + check_result = orp.MatchResult() if ( axes0.const_value is None or axes1.const_value is None or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist() ): - return False + return check_result.fail("Axes are not equal or not constant.") axes = axes0.const_value.numpy().tolist() if len(axes) != 1: - return False + return check_result.fail("Axes has more than one dimension.") if x.shape: rk = len(x.shape) else: rk = x.rank if axes[0] != -1 and axes[0] != rk - 1: - return False + return check_result.fail("Axes is not -1 or last dimension.") if ( begin0.const_value is None or end0.const_value is None or begin1.const_value is None or end1.const_value is None ): - return False + return check_result.fail("Begin or end are not constant values.") if begin0.const_value.numpy().tolist() != [0]: - return False + return check_result.fail("First begin value is not 0.") e0, b1, e1 = ( end0.const_value.numpy().tolist(), begin1.const_value.numpy().tolist(), end1.const_value.numpy().tolist(), ) if e0[0] != b1[0]: - return False + return check_result.fail("End0 is not equal to Begin1.") shape = x.shape if shape is None: - return False + return check_result.fail("Shape is not known.") last_dim = shape[-1] if not isinstance(last_dim, int): - return False + return check_result.fail("Last dimension is not known.") if last_dim != e1[0]: - return False + return check_result.fail("Last dimension is not equal to End1.") if last_dim // 2 != b1[0]: - return False - return True + return check_result.fail("Last dimension is not equal to Begin1.") + return check_result @classmethod def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): @@ -187,13 +204,14 @@ def pattern(cls, op, x, perm): return op.Transpose(x, perm=perm) @classmethod - def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: + def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() if isinstance(perm, ir.RefAttr): - return False + return check_result.fail("Permutation is not a reference attr.") if perm.type == ir.AttributeType.INTS: if perm.value == list(range(len(perm.value))): - return True - return False + return check_result + return check_result.fail("Permutation is not identity.") @classmethod def rewrite(cls, op, x: ir.Value, perm: ir.Attr): @@ -210,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2): return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2) @classmethod - def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> bool: + def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult: + check_result = orp.MatchResult() if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr): - return False - return True + return check_result.fail("Permutation values are not reference attributes.") + return check_result @classmethod def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]: @@ -257,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value): return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64))) @classmethod - def check(cls, context, x, axes1, axes2) -> bool: + def check(cls, context, x, axes1, axes2) -> orp.MatchResult: + check_result = orp.MatchResult() del context # Unused del x # Unused # Currently restricted to single element positive axis v1 = ir_utils.get_singleton_value(axes1) v2 = ir_utils.get_singleton_value(axes2) if v1 is None or v2 is None: - return False + return check_result.fail("Axes are not constant.") if (v1 < 0) or (v2 < 0): - return False - return True + return check_result.fail("Axes are negative.") + return check_result cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 476226c6a2..57e40f9194 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -96,10 +96,16 @@ def pattern( _domain="ai.onnxruntime.fusion", ) - def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_): + def check( + self, context, inv_freq, position_ids, freqs, extra_dims, **_ + ) -> pattern.MatchResult: + check_result = pattern.MatchResult() # TODO(rama): handle redundant reshape/expand if self._const_freqs: - return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3) + if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3): + return check_result.fail("freqs is not a constant or not 3D.") + else: + return check_result if ( _ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1) ) or ( @@ -107,13 +113,15 @@ def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_): ): pass else: - return False + return check_result.fail("position_ids are not 1D or 2D tensors.") if not _ir_utils.has_rank(inv_freq, 3): - return False + return check_result.fail("inv_freq is not 3D.") inv_freq_shape = inv_freq.shape if inv_freq.const_value is None: # TODO: should this be inv_freq_shape? - return False - return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 + return check_result.fail("inv_freq is not a constant.") + if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1: + return check_result.fail("inv_freq is not of shape [1, ., 1].") + return check_result def rewrite( self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_ diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 65496ec8bd..69ce879050 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -15,13 +15,14 @@ def pattern(cls, op, x, y, cst): return op.Div(op.MatMul(x, y), cst) @classmethod - def check(cls, context, x, y, cst) -> bool: + def check(cls, context, x, y, cst) -> orp.MatchResult: + check_result = orp.MatchResult() if cst.const_value is None: - return False + return check_result.fail("Divisor is a None value.") value = cst.const_value.numpy() if value.size > 1: - return False - return True + return check_result.fail("Divisor is not a scalar value.") + return check_result @classmethod def rewrite(cls, op, x, y, cst): @@ -38,12 +39,13 @@ def pattern(cls, op, x, y, cst): return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst) @classmethod - def check(cls, context, x, y, cst) -> bool: + def check(cls, context, x, y, cst) -> orp.MatchResult: + check_result = orp.MatchResult() if cst.const_value is None: - return False + return check_result.fail("Divisor is a None value.") if cst.const_value.numpy().size > 1: - return False - return True + return check_result.fail("Divisor is not a scalar value.") + return check_result @classmethod def rewrite(cls, op, x, y, cst): @@ -65,11 +67,14 @@ class _TransposeMatMulBase(orp.RewriteRuleAsClass): _pos: ClassVar = 1 @classmethod - def check(cls, context, x, y) -> bool: + def check(cls, context, x, y) -> orp.MatchResult: + check_result = orp.MatchResult() perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015 expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - return perm == expected_perm + if perm != expected_perm: + return check_result.fail("Permutation values for Transpose are not correct.") + return check_result @classmethod def rewrite(cls, op, x, y): @@ -126,13 +131,16 @@ def pattern(cls, op, x, y): return op.Transpose(op.MatMul(x, y)) @classmethod - def check(cls, context, x, y) -> bool: + def check(cls, context, x, y) -> orp.MatchResult: + check_result = orp.MatchResult() matmul = list(x.uses())[0][0] # noqa: RUF015 transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 perm = transpose.attributes["perm"].value expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - return perm == expected_perm + if perm != expected_perm: + return check_result.fail("Permutation values for Transpose are not correct.") + return check_result @classmethod def rewrite(cls, op, x, y): diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4bad28c789..5159ebd2a0 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -94,7 +94,8 @@ def check( # key_transposed, # attention_reshaped, **_, - ): + ) -> pattern.MatchResult: + check_result = pattern.MatchResult() # bindings: dict[str, int] = {} # status = ( # _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) @@ -110,7 +111,7 @@ def check( # return False # if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: # return False - return True + return check_result def rewrite( self, diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 0563dc4edd..fa2c2dd224 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -163,29 +163,49 @@ def check( key_BSHDh, value_BSHDh, **_, - ): + ) -> pattern.MatchResult: + check_result = pattern.MatchResult() bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _check_shape(bindings, val, dims) + if not _check_shape(bindings, val, dims): + return check_result.fail( + f"Shape mismatch: {val} does not match expected dimensions {dims}" + ) if no_match(query_BSD, ["B", "S", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']" + ) if no_match(key_BSD, ["B", "Skv", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']" + ) if no_match(value_BSD, ["B", "Skv", "D"]): - return False + return check_result.fail( + f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']" + ) if no_match(past_key, ["B", "H", "Spast", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']" + ) if no_match(past_value, ["B", "H", "Spast", "Dv"]): - return False + return check_result.fail( + f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']" + ) if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + ) if no_match(key_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + ) if no_match(value_BSHDh, ["B", "S", "H", "Dh"]): - return False + return check_result.fail( + f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + ) # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) # But this also, unforunately, depends on ORT version. @@ -193,7 +213,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # or check Reshape's shape-input value - return True + return check_result def rewrite( self, diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index 4cea9d7b90..60145ed284 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -52,21 +52,22 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = op.Cast(normalized, to=target_dtype) return op.Mul(scale, normalized) - def check(self, op, x, scale, epsilon, compute_dtype, target_dtype): + def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() # epsilon must be a scalar epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): # TODO: support other types - return False + return check_result.fail("Epsilon is not a float value.") # input and output must be same dtype if x.dtype not in float_types: - return False + return check_result.fail("Input is not a float type.") if scale.dtype not in float_types: - return False + return check_result.fail("Scale is not a float type.") stash_dtype = compute_dtype.value if self._cast_input else x.dtype if stash_dtype not in fp_float_types: - return False - return True + return check_result.fail("Normalization precision is not a float or double type.") + return check_result def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): stash_dtype = compute_dtype.value if self._cast_input else x.dtype diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index 8eb7c26f9b..be10c44841 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -30,24 +30,29 @@ def __init__(self): def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin - def check(self, op, x, start1, end1, start2, end2, **_): + def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: + check_result = pattern.MatchResult() # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: - return False + return check_result.fail("Input is not a 4D tensor.") if not isinstance(x.shape[1], int): - return False + return check_result.fail("Input dimension 1 is not an integer.") head_size = x.shape[3] if not isinstance(head_size, int): - return False + return check_result.fail("Head size is not an integer.") half_head_size = head_size // 2 # Check that x is being split into two equal halves of size half_head_size - return ( + if not ( _ir_utils.is_singleton_value(start1, 0) and _ir_utils.is_singleton_value(end1, half_head_size) and _ir_utils.is_singleton_value(start2, half_head_size) and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) - ) + ): + return check_result.fail( + "x is not being split into two equal halves of size half_head_size." + ) + return check_result def rewrite(self, op, x, cos, sin, **_): num_heads = x.shape[1] @@ -69,22 +74,23 @@ def pattern(self, op, x, end1, start2): ) return op.Concat(x_part_1_rope, x_part_2, axis=-1) - def check(self, op, x, end1, start2, x_part_1_rope, **_): + def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: + check_result = pattern.MatchResult() end1_value = _ir_utils.get_singleton_value(end1) start2_value = _ir_utils.get_singleton_value(start2) if not isinstance(end1_value, int) or not isinstance(start2_value, int): - return False + return check_result.fail("end1 and start2 are not integers.") if end1_value != start2_value: - return False + return check_result.fail("end1 and start2 are not equal.") rotary_embedding_attributes = x_part_1_rope.producer().attributes if "rotary_embedding_dim" in rotary_embedding_attributes: - return False + return check_result.fail("rotary_embedding_dim is not a singleton value.") if ( "interleaved" in rotary_embedding_attributes and rotary_embedding_attributes["interleaved"].value != 0 ): - return False - return True + return check_result.fail("interleaved is not equal to 0.") + return check_result def rewrite(self, op, x, end1, x_part_1_rope, **_): # Create a modified version of the RotaryEmbedding op: diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 8eefc9aec0..5c53e49df8 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -41,14 +41,15 @@ def pattern( return attn_output def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): + check_result = pattern.MatchResult() # Check that the scaling factors match what SDPA implements: # We need to know the hidden size to check the scaling factors. if query is None or query.shape is None or len(query.shape) < 2: - return False + return check_result.fail("Query shape is not known or has less than 2 dimensions.") hidden_size = query.shape[-1] if not isinstance(hidden_size, int): - return False + return check_result.fail("Hidden size is not an integer.") expected_scaling_factor = math.sqrt(hidden_size) if self._use_mul: expected_scaling_factor = 1.0 / expected_scaling_factor @@ -57,17 +58,23 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, # Check if query_scale and key_scale are scalars == sqrt(expected_scaling_factor) sqrt_scaling_factor = math.sqrt(expected_scaling_factor) if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "Query scale is not a scalar or does not match the expected scaling factor." + ) if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "Key scale is not a scalar or does not match the expected scaling factor." + ) else: # Check if qk_scale is a scalar == expected_scaling_factor) if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3): - return False + return check_result.fail( + "QK scale is not a scalar or does not match the expected scaling factor." + ) # check ranks/shapes - return True + return check_result def rewrite(self, op, query, key_transposed, value, mask, **_): if self._use_mask: @@ -118,6 +125,10 @@ def rewrite(self, op, query, key_transposed, value, mask, **_): ) -def fuse_sdpa(model: ir.Model) -> int: +def fuse_sdpa(model: ir.Model, debug: bool = False) -> int: count = sdpa_rules.apply_to_model(model) + if debug and count == 0: + tracer = pattern.MatchingTracer() + sdpa_rules.apply_to_model(model, tracer=tracer) + tracer.report() return count diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 229c76aab6..1cd79e1c42 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -171,7 +171,7 @@ def test_sdpa_fusion(self, name, script_func): # inputs = test_case.get_ort_inputs() # original_outputs = ort_run("original", model, inputs) - count = fuse_sdpa(model) + count = fuse_sdpa(model, debug=True) self.assertGreater(count, 0) # Check that the fusion was successful diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 6f7e1ea116..c480ecc37a 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1371,7 +1371,11 @@ def try_rewrite( if var.name is not None: if var.name not in match.bindings: match.bindings[var.name] = None - if not self._condition_function(context, **match.bindings): + check_match_result = self._condition_function(context, **match.bindings) + if not check_match_result: + # If check function was provided, but it failed, return the reason for failure to the tracer. + if not isinstance(check_match_result, bool): + match.fail(check_match_result.reason) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED @@ -1449,8 +1453,8 @@ def rewrite(cls, op, *_) -> Any: raise NotImplementedError("Method 'rewrite' must be overwritten.") @classmethod - def check(cls, context, *_, **__) -> bool: - return True + def check(cls, context, *_, **__) -> MatchResult: + return MatchResult() def make_rewrite_rule_from_class( @@ -1532,8 +1536,9 @@ def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") def check(self, op, *args, **kwargs): - # Default check function that always returns True. - return True + # Default check function that returns a + # MatchResult object with success always set to True. + return MatchResult() def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") @@ -1826,7 +1831,10 @@ def print(self): if self.status != MatchStatus.SUCCESS: reason = self.match_result.reason if reason: - print(f"Graph matching failed: {reason}") + if self.status == MatchStatus.CONDITION_FAILED: + print(f"Graph matching failed due to failing check condition : {reason}") + else: + print(f"Graph matching failed: {reason}") else: print("Graph matching failed.") failure_node = self.match_result._failure_node From ac8e98eb42cc1c7519037aaf2eb3edd8336f1359 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 26 Mar 2025 19:04:18 +0000 Subject: [PATCH 2/8] fix mypy errors --- onnxscript/rewriter/llama_rule_sets.py | 2 +- onnxscript/rewriter/ort_fusions/cos_sin_cache.py | 2 +- onnxscript/rewriter/ort_fusions/gqa.py | 2 +- onnxscript/rewriter/ort_fusions/mha.py | 7 ++----- onnxscript/rewriter/ort_fusions/rms_normalization.py | 2 +- onnxscript/rewriter/ort_fusions/rotary_embedding.py | 4 ++-- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 81961d5883..4a5cae30d7 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -144,7 +144,7 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1): return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1) @classmethod - def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp: + def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult: check_result = orp.MatchResult() if ( axes0.const_value is None diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 57e40f9194..20cdb7140d 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -98,7 +98,7 @@ def pattern( def check( self, context, inv_freq, position_ids, freqs, extra_dims, **_ - ) -> pattern.MatchResult: + ) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() # TODO(rama): handle redundant reshape/expand if self._const_freqs: diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 5159ebd2a0..b57519ad17 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -94,7 +94,7 @@ def check( # key_transposed, # attention_reshaped, **_, - ) -> pattern.MatchResult: + ) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() # bindings: dict[str, int] = {} # status = ( diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index fa2c2dd224..e1e82be4d9 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -163,15 +163,12 @@ def check( key_BSHDh, value_BSHDh, **_, - ) -> pattern.MatchResult: + ) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - if not _check_shape(bindings, val, dims): - return check_result.fail( - f"Shape mismatch: {val} does not match expected dimensions {dims}" - ) + return not _check_shape(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index 60145ed284..ae5404b47d 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -52,7 +52,7 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = op.Cast(normalized, to=target_dtype) return op.Mul(scale, normalized) - def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: + def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" check_result = pattern.MatchResult() # epsilon must be a scalar diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index be10c44841..b53d199868 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -30,7 +30,7 @@ def __init__(self): def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin - def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: + def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: @@ -74,7 +74,7 @@ def pattern(self, op, x, end1, start2): ) return op.Concat(x_part_1_rope, x_part_2, axis=-1) - def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: + def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() end1_value = _ir_utils.get_singleton_value(end1) start2_value = _ir_utils.get_singleton_value(start2) From add1284fa4260accb60ac5ae27b036b5d31ee369 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 26 Mar 2025 21:57:54 +0000 Subject: [PATCH 3/8] extend fail to specify failure cause --- onnxscript/rewriter/llama_rule_sets.py | 4 +- .../rewriter/ort_fusions/cos_sin_cache.py | 2 +- .../ort_fusions/fused_matmul_rule_sets.py | 4 +- .../rewriter/ort_fusions/rotary_embedding.py | 10 +++-- onnxscript/rewriter/ort_fusions/sdpa.py | 13 ++++-- onnxscript/rewriter/pattern.py | 43 +++++++++++++------ 6 files changed, 52 insertions(+), 24 deletions(-) diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py index 4a5cae30d7..f721bf5c9e 100644 --- a/onnxscript/rewriter/llama_rule_sets.py +++ b/onnxscript/rewriter/llama_rule_sets.py @@ -207,7 +207,7 @@ def pattern(cls, op, x, perm): def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult: check_result = orp.MatchResult() if isinstance(perm, ir.RefAttr): - return check_result.fail("Permutation is not a reference attr.") + return check_result.fail("Permutation is a reference attribute.") if perm.type == ir.AttributeType.INTS: if perm.value == list(range(len(perm.value))): return check_result @@ -231,7 +231,7 @@ def pattern(cls, op, x, perm1, perm2): def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult: check_result = orp.MatchResult() if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr): - return check_result.fail("Permutation values are not reference attributes.") + return check_result.fail("Permutation is a reference attribute.") return check_result @classmethod diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index 20cdb7140d..f693773e73 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -113,7 +113,7 @@ def check( ): pass else: - return check_result.fail("position_ids are not 1D or 2D tensors.") + return check_result.fail("position_ids is not a 1D or 2D tensor.") if not _ir_utils.has_rank(inv_freq, 3): return check_result.fail("inv_freq is not 3D.") inv_freq_shape = inv_freq.shape diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 69ce879050..d60d8ad300 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -18,7 +18,7 @@ def pattern(cls, op, x, y, cst): def check(cls, context, x, y, cst) -> orp.MatchResult: check_result = orp.MatchResult() if cst.const_value is None: - return check_result.fail("Divisor is a None value.") + return check_result.fail("Divisor is not a constant value.") value = cst.const_value.numpy() if value.size > 1: return check_result.fail("Divisor is not a scalar value.") @@ -42,7 +42,7 @@ def pattern(cls, op, x, y, cst): def check(cls, context, x, y, cst) -> orp.MatchResult: check_result = orp.MatchResult() if cst.const_value is None: - return check_result.fail("Divisor is a None value.") + return check_result.fail("Divisor is not a constant value.") if cst.const_value.numpy().size > 1: return check_result.fail("Divisor is not a scalar value.") return check_result diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index b53d199868..8053da8d41 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -79,12 +79,16 @@ def check(self, op, x, end1, start2, x_part_1_rope, **_) -> pattern.MatchResult: end1_value = _ir_utils.get_singleton_value(end1) start2_value = _ir_utils.get_singleton_value(start2) if not isinstance(end1_value, int) or not isinstance(start2_value, int): - return check_result.fail("end1 and start2 are not integers.") + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not integers." + ) if end1_value != start2_value: - return check_result.fail("end1 and start2 are not equal.") + return check_result.fail( + "The end1 value of first slice and start2 value of second slice are not equal." + ) rotary_embedding_attributes = x_part_1_rope.producer().attributes if "rotary_embedding_dim" in rotary_embedding_attributes: - return check_result.fail("rotary_embedding_dim is not a singleton value.") + return check_result.fail("rotary_embedding_dim attribute already specified.") if ( "interleaved" in rotary_embedding_attributes and rotary_embedding_attributes["interleaved"].value != 0 diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 5c53e49df8..a277f7199f 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -46,7 +46,9 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, # We need to know the hidden size to check the scaling factors. if query is None or query.shape is None or len(query.shape) < 2: - return check_result.fail("Query shape is not known or has less than 2 dimensions.") + return check_result.fail( + "Query shape is not known or has less than 2 dimensions.", query + ) hidden_size = query.shape[-1] if not isinstance(hidden_size, int): return check_result.fail("Hidden size is not an integer.") @@ -59,17 +61,20 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, sqrt_scaling_factor = math.sqrt(expected_scaling_factor) if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3): return check_result.fail( - "Query scale is not a scalar or does not match the expected scaling factor." + "Query scale is not a scalar or does not match the expected scaling factor.", + query_scale, ) if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3): return check_result.fail( - "Key scale is not a scalar or does not match the expected scaling factor." + "Key scale is not a scalar or does not match the expected scaling factor.", + key_scale, ) else: # Check if qk_scale is a scalar == expected_scaling_factor) if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3): return check_result.fail( - "QK scale is not a scalar or does not match the expected scaling factor." + "QK scale is not a scalar or does not match the expected scaling factor.", + qk_scale, ) # check ranks/shapes diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index c480ecc37a..08ccffb960 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -330,17 +330,24 @@ def __init__(self) -> None: self.outputs: list[ir.Value] = [] # For a failed match, _reason is a string that describes the reason for the failure. self._reason: str = "" - # Track the node that caused the failure. - # TODO: May be useful to extend this to be a collection of Nodes and Values. - self._failure_node: ir.Node | None = None + # Track the node(s) or value(s) that caused the failure. + self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = [] def __bool__(self): return self._success - def fail(self, reason: str = "", node: ir.Node | None = None) -> MatchResult: + def fail( + self, + reason: str = "", + failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None, + ) -> MatchResult: self._success = False self._reason = reason - self._failure_node = node + if failure_source is not None: + if isinstance(failure_source, list): + self._failure_nodes_and_values.extend(failure_source) + else: + self._failure_nodes_and_values.append(failure_source) return self @property @@ -1374,8 +1381,10 @@ def try_rewrite( check_match_result = self._condition_function(context, **match.bindings) if not check_match_result: # If check function was provided, but it failed, return the reason for failure to the tracer. - if not isinstance(check_match_result, bool): - match.fail(check_match_result.reason) + if isinstance(check_match_result, MatchResult): + match.fail( + check_match_result.reason, check_match_result._failure_nodes_and_values + ) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED @@ -1453,7 +1462,7 @@ def rewrite(cls, op, *_) -> Any: raise NotImplementedError("Method 'rewrite' must be overwritten.") @classmethod - def check(cls, context, *_, **__) -> MatchResult: + def check(cls, context, *_, **__) -> bool | MatchResult: return MatchResult() @@ -1837,10 +1846,20 @@ def print(self): print(f"Graph matching failed: {reason}") else: print("Graph matching failed.") - failure_node = self.match_result._failure_node - if failure_node: - print("Failure at or around node:") - failure_node.display() + failure_nodes_and_values = self.match_result._failure_nodes_and_values + print("Failure at or around nodes/values:") + + if failure_nodes_and_values: + if isinstance(failure_nodes_and_values, list): + for failure_cause in failure_nodes_and_values: + if isinstance(failure_cause, ir.Node): + failure_cause.display() + elif isinstance(failure_cause, ir.Value): + print(failure_cause) + elif isinstance(failure_nodes_and_values, ir.Node): + failure_nodes_and_values.display() + elif isinstance(failure_nodes_and_values, ir.Value): + print(failure_nodes_and_values) print("Matched nodes:") import onnxscript.rewriter._ir_utils as ir_utils From 6790795aed037c8290587651b7bbf302df0cfb5b Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 26 Mar 2025 22:03:07 +0000 Subject: [PATCH 4/8] fix empty case --- onnxscript/rewriter/pattern.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 08ccffb960..d07f9e91c7 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1382,9 +1382,12 @@ def try_rewrite( if not check_match_result: # If check function was provided, but it failed, return the reason for failure to the tracer. if isinstance(check_match_result, MatchResult): - match.fail( - check_match_result.reason, check_match_result._failure_nodes_and_values - ) + if check_match_result._failure_nodes_and_values: + match.fail( + check_match_result.reason, + check_match_result._failure_nodes_and_values, + ) + match.fail(check_match_result.reason) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED From d3b69c369e90a4044dec597622c2ba595105c9a1 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 26 Mar 2025 22:04:50 +0000 Subject: [PATCH 5/8] minor fix --- onnxscript/rewriter/pattern.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d07f9e91c7..dcf993592a 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1387,7 +1387,8 @@ def try_rewrite( check_match_result.reason, check_match_result._failure_nodes_and_values, ) - match.fail(check_match_result.reason) + else: + match.fail(check_match_result.reason) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED From d841ea261375e9380a4971968a670802bc5d50e2 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 27 Mar 2025 17:16:05 +0000 Subject: [PATCH 6/8] remove unnecessary if else --- onnxscript/rewriter/pattern.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index dcf993592a..793675b4ab 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1382,13 +1382,10 @@ def try_rewrite( if not check_match_result: # If check function was provided, but it failed, return the reason for failure to the tracer. if isinstance(check_match_result, MatchResult): - if check_match_result._failure_nodes_and_values: - match.fail( - check_match_result.reason, - check_match_result._failure_nodes_and_values, - ) - else: - match.fail(check_match_result.reason) + match.fail( + check_match_result.reason, + check_match_result._failure_nodes_and_values, + ) if tracer: tracer.log( self, graph_or_function, node, match, MatchStatus.CONDITION_FAILED @@ -1852,18 +1849,9 @@ def print(self): print("Graph matching failed.") failure_nodes_and_values = self.match_result._failure_nodes_and_values print("Failure at or around nodes/values:") - if failure_nodes_and_values: - if isinstance(failure_nodes_and_values, list): - for failure_cause in failure_nodes_and_values: - if isinstance(failure_cause, ir.Node): - failure_cause.display() - elif isinstance(failure_cause, ir.Value): - print(failure_cause) - elif isinstance(failure_nodes_and_values, ir.Node): - failure_nodes_and_values.display() - elif isinstance(failure_nodes_and_values, ir.Value): - print(failure_nodes_and_values) + for failure_cause in failure_nodes_and_values: + failure_cause.display() print("Matched nodes:") import onnxscript.rewriter._ir_utils as ir_utils From 3010dac4e6942491b04655be8ac4a1de9e8c5714 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 27 Mar 2025 17:41:11 +0000 Subject: [PATCH 7/8] add failure nodes to some checks --- .../rewriter/ort_fusions/cos_sin_cache.py | 10 ++++---- onnxscript/rewriter/ort_fusions/mha.py | 24 ++++++++++++------- .../rewriter/ort_fusions/rms_normalization.py | 6 ++--- .../rewriter/ort_fusions/rotary_embedding.py | 6 ++--- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index f693773e73..cf0522c5ad 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -103,7 +103,7 @@ def check( # TODO(rama): handle redundant reshape/expand if self._const_freqs: if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3): - return check_result.fail("freqs is not a constant or not 3D.") + return check_result.fail("freqs is not a constant or not 3D.", freqs) else: return check_result if ( @@ -113,14 +113,14 @@ def check( ): pass else: - return check_result.fail("position_ids is not a 1D or 2D tensor.") + return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids) if not _ir_utils.has_rank(inv_freq, 3): - return check_result.fail("inv_freq is not 3D.") + return check_result.fail("inv_freq is not 3D.", inv_freq) inv_freq_shape = inv_freq.shape if inv_freq.const_value is None: # TODO: should this be inv_freq_shape? - return check_result.fail("inv_freq is not a constant.") + return check_result.fail("inv_freq is not a constant.", inv_freq) if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1: - return check_result.fail("inv_freq is not of shape [1, ., 1].") + return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq) return check_result def rewrite( diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index e1e82be4d9..8bb85f2aed 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -172,36 +172,44 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: if no_match(query_BSD, ["B", "S", "D"]): return check_result.fail( - f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']" + f"Shape mismatch: {query_BSD} does not match expected dimensions ['B', 'S', 'D']", + query_BSD, ) if no_match(key_BSD, ["B", "Skv", "D"]): return check_result.fail( - f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']" + f"Shape mismatch: {key_BSD} does not match expected dimensions ['B', 'Skv', 'D']", + query_BSD, ) if no_match(value_BSD, ["B", "Skv", "D"]): return check_result.fail( - f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']" + f"Shape mismatch: {value_BSD} does not match expected dimensions ['B', 'Skv', 'D']", + value_BSD, ) if no_match(past_key, ["B", "H", "Spast", "Dh"]): return check_result.fail( - f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']" + f"Shape mismatch: {past_key} does not match expected dimensions ['B', 'H', 'Spast', 'Dh']", + past_key, ) if no_match(past_value, ["B", "H", "Spast", "Dv"]): return check_result.fail( - f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']" + f"Shape mismatch: {past_value} does not match expected dimensions ['B', 'H', 'Spast', 'Dv']", + past_value, ) if no_match(query_BSHDh, ["B", "S", "H", "Dh"]): return check_result.fail( - f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + f"Shape mismatch: {query_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + query_BSHDh, ) if no_match(key_BSHDh, ["B", "S", "H", "Dh"]): return check_result.fail( - f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + f"Shape mismatch: {key_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + query_BSHDh, ) if no_match(value_BSHDh, ["B", "S", "H", "Dh"]): return check_result.fail( - f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']" + f"Shape mismatch: {value_BSHDh} does not match expected dimensions ['B', 'S', 'H', 'Dh']", + query_BSHDh, ) # TODO: mask shape check: ideally, it should be (1 or B, 1 or H, S, St) # But this also, unforunately, depends on ORT version. diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index ae5404b47d..55b7f190b2 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -58,12 +58,12 @@ def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.M # epsilon must be a scalar epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): # TODO: support other types - return check_result.fail("Epsilon is not a float value.") + return check_result.fail("Epsilon is not a float value.", epsilon) # input and output must be same dtype if x.dtype not in float_types: - return check_result.fail("Input is not a float type.") + return check_result.fail("Input is not a float type.", x) if scale.dtype not in float_types: - return check_result.fail("Scale is not a float type.") + return check_result.fail("Scale is not a float type.", scale) stash_dtype = compute_dtype.value if self._cast_input else x.dtype if stash_dtype not in fp_float_types: return check_result.fail("Normalization precision is not a float or double type.") diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py index 8053da8d41..5bb34cf5bf 100644 --- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py +++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py @@ -34,12 +34,12 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: check_result = pattern.MatchResult() # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: - return check_result.fail("Input is not a 4D tensor.") + return check_result.fail("Input is not a 4D tensor.", x) if not isinstance(x.shape[1], int): - return check_result.fail("Input dimension 1 is not an integer.") + return check_result.fail("Input dimension 1 is not an integer.", x) head_size = x.shape[3] if not isinstance(head_size, int): - return check_result.fail("Head size is not an integer.") + return check_result.fail("Head size is not an integer.", x) half_head_size = head_size // 2 # Check that x is being split into two equal halves of size half_head_size From 8748a514e3a62407173c1eb0e6ecac13f5c67ff0 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Fri, 28 Mar 2025 23:14:25 +0000 Subject: [PATCH 8/8] fix one pattern --- onnxscript/rewriter/generic_pattern_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index 15e522d8dc..dadaf5e8bb 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -546,8 +546,8 @@ def transpose_transpose_mapping(perm0, perm1): # replace by return [perm0[p] for p in perm1] ? return new_perm - def transpose_transpose_check(op, **_) -> pattern.MatchResult: - return pattern.MatchResult() + def transpose_transpose_check(op, **_) -> bool: + return True def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): perm0 = XT.producer().attributes.get("perm")