Skip to content

Commit 5da12ed

Browse files
committed
Rename intermediate values to _val and _const | feat(torchlib)
ghstack-source-id: 54bb9d9 Pull Request resolved: #881
1 parent 06a0a5c commit 5da12ed

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@
7676
# TODO(justinchuby): Build a context manager to handle source information.
7777

7878

79+
def _rename_intermediate_value(name: str) -> str:
80+
if name.isdigit():
81+
return f"_val_{name}"
82+
return name
83+
84+
85+
def _rename_intermediate_constant(name: str) -> str:
86+
if name.isdigit():
87+
return f"_const_{name}"
88+
return name
89+
90+
7991
class TorchScriptTensor(onnxscript_tensor.Tensor):
8092
"""A onnxscript tensor that wraps a torchscript Value."""
8193

@@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
454466
self._torch_graph, "prim::Constant", inputs=(), attributes={}
455467
)[0]
456468
value.setType(torch.OptionalType.ofTensor())
469+
value.setDebugName(_rename_intermediate_constant(value.debugName()))
457470
return value
458471

459472
if isinstance(constant, bool):
@@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
475488
raise TypeError(
476489
f"Constant input '{constant}' of type '{type(constant)}' is not supported"
477490
)
478-
return _create_op_call_in_torch_graph(
491+
value = _create_op_call_in_torch_graph(
479492
self._torch_graph,
480493
"onnx::Constant",
481494
inputs=(),
482495
attributes=dict(value=constant_tensor),
483496
)[0]
497+
value.setDebugName(_rename_intermediate_constant(value.debugName()))
498+
return value
484499

485500
@runtime_typing.checked
486501
def _add_torchscript_op_call(
@@ -524,9 +539,15 @@ def _add_torchscript_op_call(
524539
attributes=onnx_attributes,
525540
n_outputs=n_outputs,
526541
)
527-
if len(result) <= 1:
528-
return TorchScriptTensor(result[0])
529-
return tuple(TorchScriptTensor(v) for v in result)
542+
assert result, "Expected at least one output from ONNX op call."
543+
if len(result) == 1:
544+
tensor = TorchScriptTensor(result[0])
545+
tensor.name = _rename_intermediate_value(tensor.name)
546+
return tensor
547+
tensors = tuple(TorchScriptTensor(v) for v in result)
548+
for tensor in tensors:
549+
tensor.name = _rename_intermediate_value(tensor.name)
550+
return tensors
530551

531552
@runtime_typing.checked
532553
def fetch_function_proto_dict(

0 commit comments

Comments
 (0)