diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index 5eb1c56c4f..bc88237be0 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -76,6 +76,18 @@ # TODO(justinchuby): Build a context manager to handle source information. +def _rename_intermediate_value(name: str) -> str: + if name.isdigit(): + return f"_val_{name}" + return name + + +def _rename_intermediate_constant(name: str) -> str: + if name.isdigit(): + return f"_const_{name}" + return name + + class TorchScriptTensor(onnxscript_tensor.Tensor): """A onnxscript tensor that wraps a torchscript Value.""" @@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value: self._torch_graph, "prim::Constant", inputs=(), attributes={} )[0] value.setType(torch.OptionalType.ofTensor()) + value.setDebugName(_rename_intermediate_constant(value.debugName())) return value if isinstance(constant, bool): @@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value: raise TypeError( f"Constant input '{constant}' of type '{type(constant)}' is not supported" ) - return _create_op_call_in_torch_graph( + value = _create_op_call_in_torch_graph( self._torch_graph, "onnx::Constant", inputs=(), attributes=dict(value=constant_tensor), )[0] + value.setDebugName(_rename_intermediate_constant(value.debugName())) + return value @runtime_typing.checked def _add_torchscript_op_call( @@ -524,9 +539,15 @@ def _add_torchscript_op_call( attributes=onnx_attributes, n_outputs=n_outputs, ) - if len(result) <= 1: - return TorchScriptTensor(result[0]) - return tuple(TorchScriptTensor(v) for v in result) + assert result, "Expected at least one output from ONNX op call." + if len(result) == 1: + tensor = TorchScriptTensor(result[0]) + tensor.name = _rename_intermediate_value(tensor.name) + return tensor + tensors = tuple(TorchScriptTensor(v) for v in result) + for tensor in tensors: + tensor.name = _rename_intermediate_value(tensor.name) + return tensors @runtime_typing.checked def fetch_function_proto_dict( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index 21d18380e7..5739f347e9 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -16,7 +16,7 @@ from __future__ import annotations import unittest -from typing import Any, Callable, Optional, Sequence, Tuple +from typing import Callable, Optional, Sequence, Tuple import numpy as np import onnx @@ -72,18 +72,6 @@ def _should_skip_xfail_test_sample( return None, None -def _split_function_and_wrangler( - onnx_function_and_wrangler: Callable[..., Any] - | tuple[Callable[..., Any], Callable[..., Any]] -) -> tuple[Callable[..., Any], Callable[..., Any] | None]: - """Splits a function with an optional input wrangler into a function and an input wrangler.""" - if isinstance(onnx_function_and_wrangler, tuple): - return onnx_function_and_wrangler - - assert callable(onnx_function_and_wrangler) - return onnx_function_and_wrangler, None - - class TestFunctionValidity(unittest.TestCase): def test_all_script_functions_are_onnx_functions(self): for info in ops_test_data.TESTED_TORCHLIB_OPS: diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py index fae2f709ad..63c4bebaca 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_common.py @@ -517,9 +517,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, onnx.checker.check_model(onnx_model, full_check=True) except onnx.checker.ValidationError as e: raise AssertionError( - f"ONNX model is invalid: {e}. " - f"Model:\n" - f"{onnxscript.proto2text(onnx_model)}" + f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}" ) from e try: