Skip to content

[Experimental] Convert CastLike to Cast when dtype is available | feat(torchlib) #1179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
26 changes: 19 additions & 7 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,25 @@ def eval(self, schema, inputs, attributes):
if schema.name == "CastLike":
assert len(inputs) == 2
# Skip CastLike if the input and output types are the same
if (
inputs[0] is not None
and inputs[1] is not None
and inputs[0].dtype == inputs[1].dtype
and inputs[1].dtype is not None
):
return inputs[0]
src_input = inputs[0]
target_input = inputs[1]
dtypes_available = (
isinstance(src_input, TorchScriptTensor)
and isinstance(target_input, TorchScriptTensor)
and src_input.dtype is not None
and target_input.dtype is not None
)
if dtypes_available:
if src_input.dtype == target_input.dtype:
# Same type. No cast needed
return src_input
else:
# Create a Cast node
return self._graph.add_op_call(
onnx.defs.get_schema("Cast"),
(src_input,),
{"to": target_input.onnx_dtype},
)
return self._graph.add_op_call(schema, inputs, attributes)

@runtime_typing.checked
Expand Down