
Commit cabd83b

Cast-cast elimination (#2368)
Enable the cast-cast simplification, which collapses two consecutive Cast ops into a single Cast, in a couple of cases where it is valid. This pattern shows up in examples like SmolLM (FP16) and is needed for the fusion patterns to work. Also: display the replaced and replacing nodes during fusion in verbose mode. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent dcb773f commit cabd83b
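
In a nutshell, the enabled rewrite collapses two consecutive Casts into one when the type combination makes this value-preserving, e.g. (schematic, following the docstring in the diff below):

    Cast(Cast(X, to=FLOAT), to=FLOAT16)  =>  Cast(X, to=FLOAT16)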

File tree

5 files changed: +55, -42 lines


onnxscript/rewriter/_fusion_utils.py

Lines changed: 3 additions & 3 deletions
@@ -53,14 +53,14 @@ def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> C
     """
 
     def apply_to(
-        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
+        model: ir.Model, debug: bool = False, apply_shape_inference: bool = False, **kwargs
     ) -> int:
-        count = rules.apply_to_model(model)
+        count = rules.apply_to_model(model, **kwargs)
         if apply_shape_inference:
             common_passes.ShapeInferencePass()(model)
         if count == 0 and debug:
             tracer = pattern.MatchingTracer()
-            rules.apply_to_model(model, tracer=tracer)
+            rules.apply_to_model(model, tracer=tracer, **kwargs)
             tracer.report()
         return count
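
The effect of the **kwargs forwarding is that options accepted by the underlying apply_to_model (such as verbose, which the _rewrite_rule.py change below reads) can now be passed through the closure that apply_fusion_rules returns. A minimal sketch of the intended call pattern; the model path is a placeholder, and verbose is assumed to be an existing apply_to_model parameter:

    import onnx
    from onnxscript import ir
    from onnxscript.rewriter import llama_rule_sets
    from onnxscript.rewriter._fusion_utils import apply_fusion_rules

    # apply_fusion_rules wraps a rule (or rule set) into the apply_to closure above.
    apply_cast_cast = apply_fusion_rules(llama_rule_sets.cast_cast_rule)

    model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # placeholder path
    # Extra keywords (here verbose) now reach rules.apply_to_model via **kwargs.
    count = apply_cast_cast(model, debug=False, verbose=1)
    print(f"applied {count} rewrite(s)")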

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 10 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 import onnxscript.optimizer
 import onnxscript.rewriter._basics as _basics
+import onnxscript.rewriter._ir_utils as _ir_utils
 import onnxscript.rewriter._matcher as _matcher
 import onnxscript.rewriter._pattern_ir as _pattern_ir
 from onnxscript import ir
@@ -529,6 +530,15 @@ def _apply_to_graph_or_function(
                 )
                 f = ir.Function(domain, name, overload, graph=graph, attributes=())
                 model.functions[f.identifier()] = f
+
+            if verbose:
+                name = f"{rule.name}: " if rule.name else ""
+                print(f"----{name}Matched Nodes----")
+                _ir_utils.display_nodes(delta.match.nodes)
+                print("++++Replacement Nodes++++")
+                _ir_utils.display_nodes(delta.new_nodes)
+                print("++++End Replacement Nodes++++")
+
             convenience.replace_nodes_and_values(
                 graph_or_function,
                 node,
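
With a verbose level set, each applied rewrite now prints its matched nodes followed by its replacement nodes. Schematically, the output looks like the following (rule name shown assuming a rule named CastCast; node rendering comes from _ir_utils.display_nodes):

    ----CastCast: Matched Nodes----
    ... nodes removed by the rewrite ...
    ++++Replacement Nodes++++
    ... nodes inserted by the rewrite ...
    ++++End Replacement Nodes++++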

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 19 additions & 11 deletions
@@ -51,22 +51,30 @@ def check(self, context, x, to) -> orp.MatchResult:
 class CastCast(orp.RewriteRuleClassBase):
     """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
 
-    _allowed_tensor_types: ClassVar = {
-        ir.DataType.FLOAT,
-        ir.DataType.FLOAT16,
-        ir.DataType.BFLOAT16,
-        ir.DataType.DOUBLE,
-    }
+    # Simplify "cast type1 => type2 => type3" to "cast type1 => type3".
+    # This rule is not valid for all combinations of types: e.g.,
+    # it is not valid for float32 => float16 => float32 or float32 => int32 => string.
+    # TODO: fill out the list of allowed combinations; the following are just a couple
+    # that show up in practice where it is valid.
+    _allowed_type2_type3: ClassVar = frozenset(
+        {
+            (ir.DataType.FLOAT, ir.DataType.FLOAT16),
+            (ir.DataType.FLOAT, ir.DataType.BFLOAT16),
+        }
+    )
 
     def pattern(self, op, x, to, to_ignored):
         return op.Cast(op.Cast(x, to=to_ignored), to=to)
 
     def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
         check_result = orp.MatchResult()
-        if to.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Output type {to.as_int()} is not allowed")
-        if to_ignored.as_int() not in self._allowed_tensor_types:
-            return check_result.fail(f"Ignored type {to_ignored.as_int()} is not allowed")
+        type2 = to_ignored.as_int()
+        type3 = to.as_int()
+        if (type2, type3) not in self._allowed_type2_type3:
+            return check_result.fail(
+                f"Intermediate cast elimination not recognized as valid from {type2} to {type3}. "
+                f"Cast-Cast rule may be incomplete for this combination."
+            )
         return check_result
 
     def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
@@ -284,7 +292,7 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
     """
     return orp.RewriteRuleSet(
         [
-            # cast_cast_rule, # Might have precision issues.
+            cast_cast_rule,
            cast_identity_rule,
            expand_identity_rule,
            reshape_reshape_rule,
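
The asymmetry encoded in _allowed_type2_type3 can be checked numerically. A small standalone numpy sketch (independent of onnxscript) showing why a float32 intermediate is safe to eliminate while a float16 intermediate is not:

    import numpy as np

    # Valid case in this rule: X => float32 => float16 equals X => float16
    # (here X is float16, the SmolLM-style case). float32 represents every
    # float16 value exactly, so the intermediate cast changes nothing.
    x16 = np.array([0.1, 3.14159, 65504.0], dtype=np.float16)
    via_f32 = x16.astype(np.float32).astype(np.float16)
    assert np.array_equal(via_f32, x16)

    # Invalid case (deliberately absent from _allowed_type2_type3):
    # float32 => float16 => float32 loses precision, so it cannot be
    # collapsed to the identity cast float32 => float32.
    x32 = np.array([0.1], dtype=np.float32)
    roundtrip = x32.astype(np.float16).astype(np.float32)
    assert not np.array_equal(roundtrip, x32)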

onnxscript/rewriter/llama_rule_sets_test.py

Lines changed: 21 additions & 26 deletions
@@ -133,41 +133,36 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
         self.assertEqual(["Transpose"], [n.op_type for n in model.graph])
         self._check_model(model_proto, rewritten_model)
 
+    def _double_cast_model(self, ostype1, ostype2, ostype3):
+        dtype2 = ostype2.dtype
+        dtype3 = ostype3.dtype
+
+        @onnxscript.script()
+        def cast_cast_model(x):
+            intermediate = opset18.Cast(x, to=dtype2)
+            y = opset18.Cast(intermediate, to=dtype3)
+            return y
+
+        return cast_cast_model.to_model_proto(
+            input_types=[ostype1[10]], output_types=[ostype3[10]]
+        )
+
     @parameterized.parameterized.expand(
         [
-            (
-                "double_casts",
-                _make_model(
-                    onnx.helper.make_graph(
-                        [
-                            onnx.helper.make_node(
-                                "Cast", ["X"], ["Xc"], to=onnx.TensorProto.FLOAT16
-                            ),
-                            onnx.helper.make_node(
-                                "Cast", ["Xc"], ["Y"], to=onnx.TensorProto.DOUBLE
-                            ),
-                        ],
-                        "name",
-                        [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])],
-                        [
-                            onnx.helper.make_tensor_value_info(
-                                "Y", onnx.TensorProto.DOUBLE, [None, None, None]
-                            )
-                        ],
-                    ),
-                    opset_imports=[onnx.helper.make_opsetid("", 18)],
-                ),
-            ),
+            ("float16_float_float16", ot.FLOAT16, ot.FLOAT, ot.FLOAT16),
         ]
     )
-    def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model):
+    def test_llama_p0_rule_set_cast_cast(self, _: str, type1, type2, type3):
         rule_set = llama_rule_sets.cast_cast_rule
-        model_proto = ir.serde.serialize_model(model)
+        model_proto = self._double_cast_model(type1, type2, type3)
+        model = ir.serde.deserialize_model(model_proto)
         rule_set.apply_to_model(model)
         rewritten_model = ir.serde.serialize_model(model)
 
         self.assertEqual(["Cast"], [n.op_type for n in model.graph])
-        self._check_model(model_proto, rewritten_model, atol=1e-2)
+        # TODO: enable the numerical check with (random) fp16 inputs
+        # self._check_model(model_proto, rewritten_model, atol=1e-2)
+        del rewritten_model  # to avoid unused variable warning
 
     @parameterized.parameterized.expand(
         [
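
For reference, this is what the new parameterized test builds and checks, written out as a standalone sketch. It mirrors _double_cast_model with the one expanded case inlined; `ot` stands for onnxscript's tensor types, as in the test module's imports, and the .dtype attribute usage follows the helper above:

    import onnxscript
    from onnxscript import ir, opset18
    from onnxscript import onnx_types as ot
    from onnxscript.rewriter import llama_rule_sets

    FLOAT_DTYPE = ot.FLOAT.dtype      # intermediate type (type2)
    FLOAT16_DTYPE = ot.FLOAT16.dtype  # final type (type3)

    @onnxscript.script()
    def cast_cast_model(x):
        intermediate = opset18.Cast(x, to=FLOAT_DTYPE)
        y = opset18.Cast(intermediate, to=FLOAT16_DTYPE)
        return y

    model_proto = cast_cast_model.to_model_proto(
        input_types=[ot.FLOAT16[10]], output_types=[ot.FLOAT16[10]]
    )
    model = ir.serde.deserialize_model(model_proto)
    llama_rule_sets.cast_cast_rule.apply_to_model(model)
    assert [n.op_type for n in model.graph] == ["Cast"]  # two Casts collapsed into one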

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@ def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[
 
     model = _pre_optimize(model)
 
-    def fuse(func, apply_shape_inference: bool = False):
-        return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
+    def fuse(func, **kwargs):
+        return func(model, debug=debug, **kwargs)
 
     fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
     fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
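
A self-contained sketch of the generalized helper pattern (fuse_a/fuse_b are stand-ins for the real fusion functions): keyword options such as apply_shape_inference now flow through unchanged, and future options need no change to fuse itself.

    # Stand-ins for the real fusion entry points; each accepts the forwarded options.
    def fuse_a(model, debug=False, apply_shape_inference=False):
        return 1  # pretend one fusion was applied

    def fuse_b(model, debug=False, **kwargs):
        return 0  # pretend nothing matched

    def fuse_xformers_sketch(model, debug=False):
        fusion_count = {}

        def fuse(func, **kwargs):
            # debug is always forwarded; any other keyword passes through unchanged.
            return func(model, debug=debug, **kwargs)

        fusion_count["a"] = fuse(fuse_a, apply_shape_inference=True)
        fusion_count["b"] = fuse(fuse_b)
        return fusion_count

    print(fuse_xformers_sketch(model=None))  # {'a': 1, 'b': 0}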
