Skip to content

Commit 9c7de8a

Browse files
authored
Rename intermediate values to _val and _const | feat(torchlib) (#881)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #883 * #882 * __->__ #881 The torchscript ONNX graph generator creates numeric value names by default (`0`, `1`). These are not legal ONNX tensor names, since ONNX requires the names to be valid C variable names. This change updates the names by prepending a prefix `_val_` or `_const_` to make them valid ONNX names. It also improves readability by making the names less likely to be confused with shape values. I decided to use the `_` prefix to reduce the chance of name collision with FX names. After: ``` < ir_version: 8, opset_import: ["" : 18], producer_name: "pytorch", producer_version: "2.1.0" > torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] _val_10) { _val_2 = Transpose <perm = [0, 1, 2, 3]> (input_0) _val_3 = Max (input_1_3) _val_4 = Shape <start = 0> (_val_3) _val_5 = Expand (input_1_3, _val_4) _const_6 = Constant <value = int64 {-1}> () _val_7 = Unsqueeze (_val_5, _const_6) _val_8 = Concat <axis = -1> (_val_7) _val_9 = GatherND <batch_dims = 0> (_val_2, _val_8) _val_10 = Transpose <perm = [0, 1, 2, 3]> (_val_9) } ``` Before: ``` < ir_version: 8, opset_import: ["" : 18], producer_name: "pytorch", producer_version: "2.1.0" > torch_jit (float[5,5,5,5] input_0, int64[2] input_1_3) => (float[5,5,5,2] 10) { 2 = Transpose <perm = [0, 1, 2, 3]> (input_0) 3 = Max (input_1_3) 4 = Shape <start = 0> (3) 5 = Expand (input_1_3, 4) 6 = Constant <value = int64 {-1}> () 7 = Unsqueeze (5, 6) 8 = Concat <axis = -1> (7) 9 = GatherND <batch_dims = 0> (2, 8) 10 = Transpose <perm = [0, 1, 2, 3]> (9) } ```
1 parent 06a0a5c commit 9c7de8a

1 file changed

Lines changed: 25 additions & 4 deletions

File tree

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@
7676
# TODO(justinchuby): Build a context manager to handle source information.
7777

7878

79+
def _rename_intermediate_value(name: str) -> str:
80+
if name.isdigit():
81+
return f"_val_{name}"
82+
return name
83+
84+
85+
def _rename_intermediate_constant(name: str) -> str:
86+
if name.isdigit():
87+
return f"_const_{name}"
88+
return name
89+
90+
7991
class TorchScriptTensor(onnxscript_tensor.Tensor):
8092
"""A onnxscript tensor that wraps a torchscript Value."""
8193

@@ -454,6 +466,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
454466
self._torch_graph, "prim::Constant", inputs=(), attributes={}
455467
)[0]
456468
value.setType(torch.OptionalType.ofTensor())
469+
value.setDebugName(_rename_intermediate_constant(value.debugName()))
457470
return value
458471

459472
if isinstance(constant, bool):
@@ -475,12 +488,14 @@ def _add_constant_to_graph(self, constant) -> torch.Value:
475488
raise TypeError(
476489
f"Constant input '{constant}' of type '{type(constant)}' is not supported"
477490
)
478-
return _create_op_call_in_torch_graph(
491+
value = _create_op_call_in_torch_graph(
479492
self._torch_graph,
480493
"onnx::Constant",
481494
inputs=(),
482495
attributes=dict(value=constant_tensor),
483496
)[0]
497+
value.setDebugName(_rename_intermediate_constant(value.debugName()))
498+
return value
484499

485500
@runtime_typing.checked
486501
def _add_torchscript_op_call(
@@ -524,9 +539,15 @@ def _add_torchscript_op_call(
524539
attributes=onnx_attributes,
525540
n_outputs=n_outputs,
526541
)
527-
if len(result) <= 1:
528-
return TorchScriptTensor(result[0])
529-
return tuple(TorchScriptTensor(v) for v in result)
542+
assert result, "Expected at least one output from ONNX op call."
543+
if len(result) == 1:
544+
tensor = TorchScriptTensor(result[0])
545+
tensor.name = _rename_intermediate_value(tensor.name)
546+
return tensor
547+
tensors = tuple(TorchScriptTensor(v) for v in result)
548+
for tensor in tensors:
549+
tensor.name = _rename_intermediate_value(tensor.name)
550+
return tensors
530551

531552
@runtime_typing.checked
532553
def fetch_function_proto_dict(

0 commit comments

Comments
 (0)