7676# TODO(justinchuby): Build a context manager to handle source information.
7777
7878
79+ def _rename_intermediate_value (name : str ) -> str :
80+ if name .isdigit ():
81+ return f"_val_{ name } "
82+ return name
83+
84+
85+ def _rename_intermediate_constant (name : str ) -> str :
86+ if name .isdigit ():
87+ return f"_const_{ name } "
88+ return name
89+
90+
7991class TorchScriptTensor (onnxscript_tensor .Tensor ):
8092 """A onnxscript tensor that wraps a torchscript Value."""
8193
@@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
454466 self ._torch_graph , "prim::Constant" , inputs = (), attributes = {}
455467 )[0 ]
456468 value .setType (torch .OptionalType .ofTensor ())
469+ value .setDebugName (_rename_intermediate_constant (value .debugName ()))
457470 return value
458471
459472 if isinstance (constant , bool ):
@@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
475488 raise TypeError (
476489 f"Constant input '{ constant } ' of type '{ type (constant )} ' is not supported"
477490 )
478- return _create_op_call_in_torch_graph (
491+ value = _create_op_call_in_torch_graph (
479492 self ._torch_graph ,
480493 "onnx::Constant" ,
481494 inputs = (),
482495 attributes = dict (value = constant_tensor ),
483496 )[0 ]
497+ value .setDebugName (_rename_intermediate_constant (value .debugName ()))
498+ return value
484499
485500 @runtime_typing .checked
486501 def _add_torchscript_op_call (
@@ -524,9 +539,15 @@ def _add_torchscript_op_call(
524539 attributes = onnx_attributes ,
525540 n_outputs = n_outputs ,
526541 )
527- if len (result ) <= 1 :
528- return TorchScriptTensor (result [0 ])
529- return tuple (TorchScriptTensor (v ) for v in result )
542+ assert result , "Expected at least one output from ONNX op call."
543+ if len (result ) == 1 :
544+ tensor = TorchScriptTensor (result [0 ])
545+ tensor .name = _rename_intermediate_value (tensor .name )
546+ return tensor
547+ tensors = tuple (TorchScriptTensor (v ) for v in result )
548+ for tensor in tensors :
549+ tensor .name = _rename_intermediate_value (tensor .name )
550+ return tensors
530551
531552 @runtime_typing .checked
532553 def fetch_function_proto_dict (
0 commit comments