Skip to content

Commit bafe515

Browse files
committed
Convert CastLike to Cast when dtype is available | feat(torchlib)
ghstack-source-id: d4b13d6 Pull Request resolved: #1179 Signed-off-by: Justin Chu <[email protected]>
1 parent 7051dc6 commit bafe515

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,25 @@ def eval(self, schema, inputs, attributes):
265265
if schema.name == "CastLike":
266266
assert len(inputs) == 2
267267
# Skip CastLike if the input and output types are the same
268-
if (
269-
inputs[0] is not None
270-
and inputs[1] is not None
271-
and inputs[0].dtype == inputs[1].dtype
272-
and inputs[1].dtype is not None
273-
):
274-
return inputs[0]
268+
src_input = inputs[0]
269+
target_input = inputs[1]
270+
dtypes_available = (
271+
isinstance(src_input, TorchScriptTensor)
272+
and isinstance(target_input, TorchScriptTensor)
273+
and src_input.dtype is not None
274+
and target_input.dtype is not None
275+
)
276+
if dtypes_available:
277+
if src_input.dtype == target_input.dtype:
278+
# Same type. No cast needed
279+
return src_input
280+
else:
281+
# Create a Cast node
282+
return self._graph.add_op_call(
283+
onnx.defs.get_schema("Cast"),
284+
(src_input,),
285+
{"to": target_input.onnx_dtype},
286+
)
275287
return self._graph.add_op_call(schema, inputs, attributes)
276288

277289
@runtime_typing.checked

0 commit comments

Comments
 (0)