Cast-cast elimination #2368

Merged (2 commits, Jun 8, 2025)

6 changes: 3 additions & 3 deletions onnxscript/rewriter/_fusion_utils.py
@@ -53,14 +53,14 @@
     """

     def apply_to(
-        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
+        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False, **kwargs
     ) -> int:
-        count = rules.apply_to_model(model)
+        count = rules.apply_to_model(model, **kwargs)
         if apply_shape_inference:
             common_passes.ShapeInferencePass()(model)
         if count == 0 and debug:
             tracer = pattern.MatchingTracer()
-            rules.apply_to_model(model, tracer=tracer)
+            rules.apply_to_model(model, tracer=tracer, **kwargs)
             tracer.report()
         return count

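The forwarded **kwargs let callers of the generated apply_to entry points pass rule-level options (such as the verbose flag used in _rewrite_rule.py below) without each fusion wrapper naming them. A minimal usage sketch; the import path and model path are illustrative assumptions, not code from this PR:

import onnxscript.ir as ir
from onnxscript.rewriter.ort_fusions import fuse_rms_normalization  # assumed import path

model = ir.load("model.onnx")  # illustrative path
# verbose is not a named parameter of apply_to; it reaches
# RewriteRuleSet.apply_to_model via **kwargs and triggers the
# matched/replacement-node printing added in _rewrite_rule.py.
count = fuse_rms_normalization(model, verbose=1)
print(f"Applied {count} fusion(s)")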
10 changes: 10 additions & 0 deletions onnxscript/rewriter/_rewrite_rule.py
@@ -15,6 +15,7 @@
 import onnxscript.optimizer
 import onnxscript.rewriter._basics as _basics
+import onnxscript.rewriter._ir_utils as _ir_utils
 import onnxscript.rewriter._matcher as _matcher
 import onnxscript.rewriter._pattern_ir as _pattern_ir
 from onnxscript import ir
@@ -529,6 +530,15 @@
             )
             f = ir.Function(domain, name, overload, graph=graph, attributes=())
             model.functions[f.identifier()] = f

+            if verbose:
+                name = f"{rule.name}: " if rule.name else ""
+                print(f"----{name}Matched Nodes----")
+                _ir_utils.display_nodes(delta.match.nodes)
+                print("++++Replacement Nodes++++")
+                _ir_utils.display_nodes(delta.new_nodes)
+                print("++++End Replacement Nodes++++")
+
             convenience.replace_nodes_and_values(
                 graph_or_function,
                 node,
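For readers unfamiliar with _ir_utils.display_nodes: it renders the matched and replacement nodes when verbose is set. As a rough mental model only (an assumption, not the repo's implementation):

from collections.abc import Iterable

from onnxscript import ir


def display_nodes(nodes: Iterable[ir.Node]) -> None:
    # Assumed behavior: print one readable line per node, outputs = op(inputs).
    for node in nodes:
        inputs = ", ".join(v.name or "?" for v in node.inputs if v is not None)
        outputs = ", ".join(v.name or "?" for v in node.outputs)
        print(f"{outputs} = {node.op_type}({inputs})")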
30 changes: 19 additions & 11 deletions onnxscript/rewriter/llama_rule_sets.py
@@ -51,22 +51,30 @@
 class CastCast(orp.RewriteRuleClassBase):
     """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""

-    _allowed_tensor_types: ClassVar = {
-        ir.DataType.FLOAT,
-        ir.DataType.FLOAT16,
-        ir.DataType.BFLOAT16,
-        ir.DataType.DOUBLE,
-    }
+    # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
+    # This rule is not valid for all combinations of types: e.g., it is not
+    # valid for float32 => float16 => float32 or float32 => int32 => string.
+    # TODO: fill out the list of allowed combinations; the following are just
+    # a couple that show up in practice where the rewrite is valid.
+    _allowed_type2_type3: ClassVar = frozenset(
+        {
+            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
+            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
+        }
+    )

     def pattern(self, op, x, to, to_ignored):
         return op.Cast(op.Cast(x, to=to_ignored), to=to)

     def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
         check_result = orp.MatchResult()
-        if to.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Output type {to.as_int()} is not allowed")
-        if to_ignored.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Ignored type {to_ignored.as_int()} is not allowed")
+        type2 = to_ignored.as_int()
+        type3 = to.as_int()
+        if (type2, type3) not in self._allowed_type2_type3:
+            return check_result.fail(
+                f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
+                f"Cast-Cast rule may be incomplete for this combination."
+            )
         return check_result

     def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):

Review thread on the new check method:

Collaborator (suggested change on the line `type2 = to_ignored.as_int()`):
    type2 = ir.DataType(to_ignored.as_int())

Collaborator (suggested change on the line `type3 = to.as_int()`):
    type3 = ir.DataType(to.as_int())

Collaborator: Maybe allow the rewrite when type2 == type3, or when both types
satisfy is_floating_point() and type2.itemsize > type3.itemsize?

Author: The first case is already covered by redundant-cast elimination. The
second is what I think too (at least for float types; something similar may
hold for int types), but I am keeping it simple here. (Do the methods
is_floating_point / itemsize exist in ir.DataType?)

Collaborator: Yes, the methods exist. I think it's pretty simple to implement.
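Sketched below is the generalization the thread converges on, assuming (as the reviewer confirms) that ir.DataType provides is_floating_point() and itemsize. A possible follow-up, not code from this PR:

from onnxscript import ir


def _intermediate_cast_is_safe(type2: ir.DataType, type3: ir.DataType) -> bool:
    # Cases the review thread considers safe; deliberately conservative.
    if type2 == type3:
        # Already handled by redundant-cast elimination, but trivially safe.
        return True
    if type2.is_floating_point() and type3.is_floating_point():
        # A strictly wider intermediate float preserves everything the final,
        # narrower cast can represent (modulo double-rounding effects, which
        # the allowed list above already tolerates).
        return type2.itemsize > type3.itemsize
    return False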
@@ -284,7 +292,7 @@
"""
return orp.RewriteRuleSet(
[
# cast_cast_rule, # Might have precision issues.
cast_cast_rule,
cast_identity_rule,
expand_identity_rule,
reshape_reshape_rule,
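With cast_cast_rule re-enabled, the rewrite fires during ordinary rule-set application. A minimal usage sketch (llama_p0_rule_set is taken to be the builder this list belongs to, per the test names below; the model path is illustrative):

import onnxscript.ir as ir
from onnxscript.rewriter import llama_rule_sets

model = ir.load("model.onnx")  # illustrative path
count = llama_rule_sets.llama_p0_rule_set().apply_to_model(model)
print(f"{count} rewrite(s) applied, now including Cast(Cast(X)) collapsing")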
47 changes: 21 additions & 26 deletions onnxscript/rewriter/llama_rule_sets_test.py
@@ -133,41 +133,36 @@
         self.assertEqual(["Transpose"], [n.op_type for n in model.graph])
         self._check_model(model_proto, rewritten_model)

+    def _double_cast_model(self, ostype1, ostype2, ostype3):
+        dtype2 = ostype2.dtype
+        dtype3 = ostype3.dtype
+
+        @onnxscript.script()
+        def cast_cast_model(x):
+            intermediate = opset18.Cast(x, to=dtype2)
+            y = opset18.Cast(intermediate, to=dtype3)
+            return y
+
+        return cast_cast_model.to_model_proto(
+            input_types=[ostype1[10]], output_types=[ostype3[10]]
+        )
+
     @parameterized.parameterized.expand(
         [
-            (
-                "double_casts",
-                _make_model(
-                    onnx.helper.make_graph(
-                        [
-                            onnx.helper.make_node(
-                                "Cast", ["X"], ["Xc"], to=onnx.TensorProto.FLOAT16
-                            ),
-                            onnx.helper.make_node(
-                                "Cast", ["Xc"], ["Y"], to=onnx.TensorProto.DOUBLE
-                            ),
-                        ],
-                        "name",
-                        [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
-                        [
-                            onnx.helper.make_tensor_value_info(
-                                "Y", onnx.TensorProto.DOUBLE, [None, None, None]
-                            )
-                        ],
-                    ),
-                    opset_imports=[onnx.helper.make_opsetid("", 18)],
-                ),
-            ),
+            ("float16_float_float16", ot.FLOAT16, ot.FLOAT, ot.FLOAT16),
         ]
     )
-    def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model):
+    def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
         rule_set = llama_rule_sets.cast_cast_rule
-        model_proto = ir.serde.serialize_model(model)
+        model_proto = self._double_cast_model(type1, type2, type3)
+        model = ir.serde.deserialize_model(model_proto)
         rule_set.apply_to_model(model)
         rewritten_model = ir.serde.serialize_model(model)

         self.assertEqual(["Cast"], [n.op_type for n in model.graph])
-        self._check_model(model_proto, rewritten_model, atol=1e-2)
+        # TODO: (random) fp16 inputs
+        # self._check_model(model_proto, rewritten_model, atol=1e-2)
+        del rewritten_model  # to avoid unused variable warning

     @parameterized.parameterized.expand(
         [
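The numerical check in the test above stays disabled pending random fp16 inputs. One way the TODO could be addressed, sketched with onnxruntime rather than the test class's _check_model helper; the session setup and tolerance are assumptions:

import numpy as np
import onnx
import onnxruntime as ort


def assert_fp16_equivalent(
    model_proto: onnx.ModelProto,
    rewritten_proto: onnx.ModelProto,
    atol: float = 1e-2,
) -> None:
    # Input "x" is FLOAT16[10] in the float16_float_float16 case above.
    x = np.random.rand(10).astype(np.float16)

    def run(proto: onnx.ModelProto) -> np.ndarray:
        session = ort.InferenceSession(
            proto.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        return session.run(None, {"x": x})[0]

    np.testing.assert_allclose(run(model_proto), run(rewritten_proto), atol=atol)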
4 changes: 2 additions & 2 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -74,8 +74,8 @@ def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[

     model = _pre_optimize(model)

-    def fuse(func, apply_shape_inference: bool = False):
-        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+    def fuse(func, **kwargs):
+        return func(model, debug=debug, **kwargs)

     fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
     fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)