From a099668928123f09c3d9cfdd815302c83ffb3e62 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Sat, 7 Jun 2025 08:33:52 -0700
Subject: [PATCH 1/2] Cast-cast elimination

---
 onnxscript/rewriter/_fusion_utils.py        |  6 +--
 onnxscript/rewriter/_rewrite_rule.py        | 10 +++++
 onnxscript/rewriter/llama_rule_sets.py      | 30 ++++++++-----
 onnxscript/rewriter/llama_rule_sets_test.py | 47 +++++++++------------
 4 files changed, 53 insertions(+), 40 deletions(-)

diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py
index b3f298a0f3..0691f9d7de 100644
--- a/onnxscript/rewriter/_fusion_utils.py
+++ b/onnxscript/rewriter/_fusion_utils.py
@@ -53,14 +53,14 @@ def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> C
     """
 
     def apply_to(
-        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
+        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False, **kwargs
     ) -> int:
-        count = rules.apply_to_model(model)
+        count = rules.apply_to_model(model, **kwargs)
         if apply_shape_inference:
             common_passes.ShapeInferencePass()(model)
         if count == 0 and debug:
             tracer = pattern.MatchingTracer()
-            rules.apply_to_model(model, tracer=tracer)
+            rules.apply_to_model(model, tracer=tracer, **kwargs)
             tracer.report()
         return count
 
diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py
index 33f2aee8a5..3e910edd52 100644
--- a/onnxscript/rewriter/_rewrite_rule.py
+++ b/onnxscript/rewriter/_rewrite_rule.py
@@ -15,6 +15,7 @@
 import onnxscript.optimizer
 import onnxscript.rewriter._basics as _basics
+import onnxscript.rewriter._ir_utils as _ir_utils
 import onnxscript.rewriter._matcher as _matcher
 import onnxscript.rewriter._pattern_ir as _pattern_ir
 from onnxscript import ir
@@ -529,6 +530,15 @@ def _apply_to_graph_or_function(
                     )
                     f = ir.Function(domain, name, overload, graph=graph, attributes=())
                     model.functions[f.identifier()] = f
+
+                if verbose:
+                    name = f"{rule.name}: " if rule.name else ""
+                    print(f"----{name}Matched Nodes----")
+                    _ir_utils.display_nodes(delta.match.nodes)
+                    print("++++Replacement Nodes++++")
+                    _ir_utils.display_nodes(delta.new_nodes)
+                    print("++++End Replacement Nodes++++")
+
                 convenience.replace_nodes_and_values(
                     graph_or_function,
                     node,
diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py
index 0021739dfe..fa12486092 100644
--- a/onnxscript/rewriter/llama_rule_sets.py
+++ b/onnxscript/rewriter/llama_rule_sets.py
@@ -51,22 +51,30 @@ def check(self, context, x, to) -> orp.MatchResult:
 class CastCast(orp.RewriteRuleClassBase):
     """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
 
-    _allowed_tensor_types: ClassVar = {
-        ir.DataType.FLOAT,
-        ir.DataType.FLOAT16,
-        ir.DataType.BFLOAT16,
-        ir.DataType.DOUBLE,
-    }
+    # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
+    # This rule is not valid for all combinations of types: e.g.,
+    # it is not valid for float32 => float16 => float32 or float32 => int32 => string.
+    # TODO: fill out the list of allowed combinations; the following are just a couple
+    # that show up in practice where the rule is valid.
+    _allowed_type2_type3: ClassVar = frozenset(
+        {
+            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
+            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
+        }
+    )
 
     def pattern(self, op, x, to, to_ignored):
         return op.Cast(op.Cast(x, to=to_ignored), to=to)
 
     def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
         check_result = orp.MatchResult()
-        if to.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Output type {to.as_int()} is not allowed")
-        if to_ignored.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Ignored type {to_ignored.as_int()} is not allowed")
+        type2 = to_ignored.as_int()
+        type3 = to.as_int()
+        if (type2, type3) not in self._allowed_type2_type3:
+            return check_result.fail(
+                f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
+                f"Cast-Cast rule may be incomplete for this combination."
+            )
         return check_result
 
     def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
@@ -284,7 +292,7 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
     """
     return orp.RewriteRuleSet(
         [
-            # cast_cast_rule,  # Might have precision issues.
+            cast_cast_rule,
             cast_identity_rule,
             expand_identity_rule,
             reshape_reshape_rule,
diff --git a/onnxscript/rewriter/llama_rule_sets_test.py b/onnxscript/rewriter/llama_rule_sets_test.py
index 29bbcb6004..f256c0dbfa 100644
--- a/onnxscript/rewriter/llama_rule_sets_test.py
+++ b/onnxscript/rewriter/llama_rule_sets_test.py
@@ -133,41 +133,36 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
         self.assertEqual(["Transpose"], [n.op_type for n in model.graph])
         self._check_model(model_proto, rewritten_model)
 
+    def _double_cast_model(self, ostype1, ostype2, ostype3):
+        dtype2 = ostype2.dtype
+        dtype3 = ostype3.dtype
+
+        @onnxscript.script()
+        def cast_cast_model(x):
+            intermediate = opset18.Cast(x, to=dtype2)
+            y = opset18.Cast(intermediate, to=dtype3)
+            return y
+
+        return cast_cast_model.to_model_proto(
+            input_types=[ostype1[10]], output_types=[ostype3[10]]
+        )
+
     @parameterized.parameterized.expand(
         [
-            (
-                "double_casts",
-                _make_model(
-                    onnx.helper.make_graph(
-                        [
-                            onnx.helper.make_node(
-                                "Cast", ["X"], ["Xc"], to=onnx.TensorProto.FLOAT16
-                            ),
-                            onnx.helper.make_node(
-                                "Cast", ["Xc"], ["Y"], to=onnx.TensorProto.DOUBLE
-                            ),
-                        ],
-                        "name",
-                        [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
-                        [
-                            onnx.helper.make_tensor_value_info(
-                                "Y", onnx.TensorProto.DOUBLE, [None, None, None]
-                            )
-                        ],
-                    ),
-                    opset_imports=[onnx.helper.make_opsetid("", 18)],
-                ),
-            ),
+            ("float16_float_float16", ot.FLOAT16, ot.FLOAT, ot.FLOAT16),
         ]
     )
-    def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model):
+    def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
         rule_set = llama_rule_sets.cast_cast_rule
-        model_proto = ir.serde.serialize_model(model)
+        model_proto = self._double_cast_model(type1, type2, type3)
+        model = ir.serde.deserialize_model(model_proto)
         rule_set.apply_to_model(model)
         rewritten_model = ir.serde.serialize_model(model)
         self.assertEqual(["Cast"], [n.op_type for n in model.graph])
-        self._check_model(model_proto, rewritten_model, atol=1e-2)
+        # TODO: (random) fp16 inputs
+        # self._check_model(model_proto, rewritten_model, atol=1e-2)
+        del rewritten_model  # to avoid unused variable warning
 
     @parameterized.parameterized.expand(
         [

From f20aa127fce7e2ce530b1a76da9c68ecbdd88de6 Mon Sep 17 00:00:00 2001
From: Ganesan Ramalingam
Date: Sat, 7 Jun 2025 08:43:26 -0700
Subject: [PATCH 2/2] Pass on extra kwargs in fuse method

Signed-off-by: Ganesan Ramalingam
---
 onnxscript/rewriter/ort_fusions/_core.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 78a74f0e03..dd1c79b1fc 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -74,8 +74,8 @@ def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[
 
     model = _pre_optimize(model)
 
-    def fuse(func, apply_shape_inference: bool = False):
-        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+    def fuse(func, **kwargs):
+        return func(model, debug=debug, **kwargs)
 
     fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
     fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
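
Illustrative usage (a minimal sketch, not part of either patch): the snippet below exercises the
re-enabled cast_cast_rule on the same float16 -> float -> float16 chain that the new test builds.
The imports, the tensor-type .dtype attribute, and the apply_to_model call are assumed to behave
as they do in llama_rule_sets_test.py above; cast_cast_model and the dtype2/dtype3 names are
hypothetical, chosen only to mirror the test helper.

    import onnxscript
    from onnxscript import FLOAT, FLOAT16, ir, opset18
    from onnxscript.rewriter import llama_rule_sets

    dtype2 = FLOAT.dtype    # intermediate cast target (float32)
    dtype3 = FLOAT16.dtype  # final cast target (float16)

    @onnxscript.script()
    def cast_cast_model(x: FLOAT16[10]) -> FLOAT16[10]:
        # float16 -> float32 -> float16: the intermediate Cast is redundant, and
        # (FLOAT, FLOAT16) is in the rule's allowed (type2, type3) combinations.
        intermediate = opset18.Cast(x, to=dtype2)
        return opset18.Cast(intermediate, to=dtype3)

    model = ir.serde.deserialize_model(cast_cast_model.to_model_proto())
    llama_rule_sets.cast_cast_rule.apply_to_model(model)
    # After the rewrite, a single Cast node (to float16) should remain.
    print([n.op_type for n in model.graph])

With the second patch, the fuse() helper in fuse_xformers forwards any extra keyword arguments,
so a call such as fuse(fuse_rms_normalization, apply_shape_inference=True) (a hypothetical
example of the existing keyword) still reaches apply_to() unchanged, as would any options added
later.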