76
76
# TODO(justinchuby): Build a context manager to handle source information.
77
77
78
78
79
+ def _rename_intermediate_value (name : str ) -> str :
80
+ if name .isdigit ():
81
+ return f"_val_{ name } "
82
+ return name
83
+
84
+
85
+ def _rename_intermediate_constant (name : str ) -> str :
86
+ if name .isdigit ():
87
+ return f"_const_{ name } "
88
+ return name
89
+
90
+
79
91
class TorchScriptTensor (onnxscript_tensor .Tensor ):
80
92
"""A onnxscript tensor that wraps a torchscript Value."""
81
93
@@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
454
466
self ._torch_graph , "prim::Constant" , inputs = (), attributes = {}
455
467
)[0 ]
456
468
value .setType (torch .OptionalType .ofTensor ())
469
+ value .setDebugName (_rename_intermediate_constant (value .debugName ()))
457
470
return value
458
471
459
472
if isinstance (constant , bool ):
@@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
475
488
raise TypeError (
476
489
f"Constant input '{ constant } ' of type '{ type (constant )} ' is not supported"
477
490
)
478
- return _create_op_call_in_torch_graph (
491
+ value = _create_op_call_in_torch_graph (
479
492
self ._torch_graph ,
480
493
"onnx::Constant" ,
481
494
inputs = (),
482
495
attributes = dict (value = constant_tensor ),
483
496
)[0 ]
497
+ value .setDebugName (_rename_intermediate_constant (value .debugName ()))
498
+ return value
484
499
485
500
@runtime_typing .checked
486
501
def _add_torchscript_op_call (
@@ -524,9 +539,15 @@ def _add_torchscript_op_call(
524
539
attributes = onnx_attributes ,
525
540
n_outputs = n_outputs ,
526
541
)
527
- if len (result ) <= 1 :
528
- return TorchScriptTensor (result [0 ])
529
- return tuple (TorchScriptTensor (v ) for v in result )
542
+ assert result , "Expected at least one output from ONNX op call."
543
+ if len (result ) == 1 :
544
+ tensor = TorchScriptTensor (result [0 ])
545
+ tensor .name = _rename_intermediate_value (tensor .name )
546
+ return tensor
547
+ tensors = tuple (TorchScriptTensor (v ) for v in result )
548
+ for tensor in tensors :
549
+ tensor .name = _rename_intermediate_value (tensor .name )
550
+ return tensors
530
551
531
552
@runtime_typing .checked
532
553
def fetch_function_proto_dict (
0 commit comments