Commit 00ea75e

[Experimental] Export with symbolic shapes (#1172)
Experimental feature to store symbolic shapes produced by torch dynamo inside the exported ONNX model. The official ONNX spec does not support value info for nodes within a FunctionProto (onnx/onnx#5487), so the value info names are generated uniquely such that they are retrievable based on the call site and call stack.
1 parent b7f215e commit 00ea75e
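
For reference, here is a minimal sketch of how the symbolic shapes stored by this change could be read back from an exported model (the file name and value names are hypothetical):

import onnx

model = onnx.load("exported_model.onnx")  # hypothetical path
for value_info in model.graph.value_info:
    # Symbolic dimensions surface as dim_param, static ones as dim_value.
    dims = [
        d.dim_param if d.dim_param else d.dim_value
        for d in value_info.type.tensor_type.shape.dim
    ]
    print(value_info.name, dims)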

File tree

1 file changed: +153 −17 lines changed

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 153 additions & 17 deletions
@@ -90,11 +90,14 @@ def _rename_intermediate_value(name: str) -> str:
 class TorchScriptTensor(onnxscript_tensor.Tensor):
     """A onnxscript tensor that wraps a torchscript Value."""

-    def __init__(self, value: torch.Value):
+    def __init__(
+        self,
+        value: torch.Value,
+    ):
         super().__init__(None)
         self._torch_value: torch.Value = value
         self._concrete_value: Optional[np.ndarray] = None
-        self._shape: Optional[Tuple[int | None, ...]] = None
+        self._shape: Optional[Tuple[int | str | None, ...]] = None
         self._torch_dtype: Optional[torch.dtype] = None
         self._name: Optional[str] = None
         self._is_complex: bool = False
@@ -125,14 +128,17 @@ def name(self, name: str):

     @property  # type: ignore[override]
     def rank(self) -> int | None:
+        if self._shape is not None:
+            return len(self._shape)
+
         value_type = self._torch_value.type()
         if value_type is None:
             return None
         value_type = typing.cast(torch.TensorType, value_type)
         return value_type.dim()

     @property  # type: ignore[override]
-    def shape(self) -> Tuple[int | None, ...] | None:
+    def shape(self) -> Tuple[int | str | None, ...] | None:
         if self._shape is not None:
             return self._shape
@@ -149,9 +155,17 @@ def shape(self) -> Tuple[int | None, ...] | None:
         return tuple(shape)

     @shape.setter
-    def shape(self, shape: Tuple[int | None, ...]):
-        self._shape = shape
-        self._torch_value.setType(self._torch_value.type().with_sizes(list(shape)))
+    def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
+        # Normalize torch symbolic dimension size to str.
+        torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
+        self._shape = tuple(
+            str(dim.node) if isinstance(dim, torch_sym_types) else dim  # type: ignore[union-attr]
+            for dim in shape
+        )
+        # The jit api does not support assigning symbolic shapes,
+        # hence symbols are replaced with None.
+        jit_shape = tuple(dim if isinstance(dim, int) else None for dim in shape)
+        self._torch_value.setType(self._torch_value.type().with_sizes(list(jit_shape)))

     @property  # type: ignore[override]
     def dtype(self) -> torch.dtype | None:
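
The setter above keeps symbolic dimensions as strings for ONNX while erasing them for the TorchScript type. A small sketch of that split, using plain values in place of real torch.SymInt objects (a real symbolic dim would stringify via str(dim.node)):

shape = (2, "s0", None)  # one static, one symbolic, one unknown dimension
# The jit type only receives static dims; symbols and unknowns become None.
jit_shape = tuple(dim if isinstance(dim, int) else None for dim in shape)
print(jit_shape)  # (2, None, None)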
@@ -195,6 +209,15 @@ def symbolic_value(self) -> torch.Value:
         """The symbolic Value in torch.Graph."""
         return self._torch_value

+    def value_info(self) -> Optional[onnx.ValueInfoProto]:
+        try:
+            dtype = self.onnx_dtype.value
+        except torch.onnx.errors.OnnxExporterError:
+            return None
+        if dtype == onnx.TensorProto.UNDEFINED:
+            return None
+        return onnx.helper.make_tensor_value_info(self.name, dtype, self.shape)
+

 @runtime_typing.checked
 def _unwrap_tensor_to_torch_value(
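
The value_info() method above passes a shape that may mix ints and strings to onnx.helper.make_tensor_value_info, which records string entries as symbolic dim_param fields. A quick sketch (tensor name and symbol are made up):

import onnx

vi = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, ("s0", 3))
print(vi.type.tensor_type.shape.dim[0].dim_param)  # "s0" (symbolic)
print(vi.type.tensor_type.shape.dim[1].dim_value)  # 3 (static)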
@@ -223,7 +246,12 @@ def _unwrap_tensor_to_torch_value(

 @runtime_typing.checked
 def _wrap_torch_value_to_tensor(
-    value: Union[torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType]]
+    value: Union[
+        torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType]
+    ],
+    *,
+    shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> Union[
     ValidArgumentType,
     Dict[str, ValidArgumentType],
@@ -232,7 +260,12 @@ def _wrap_torch_value_to_tensor(
 ]:
     """Wrap torch.Value to TorchScriptTensor."""
     if isinstance(value, torch.Value):
-        return TorchScriptTensor(value)
+        tensor = TorchScriptTensor(value)
+        if shape is not None:
+            tensor.shape = shape
+        if dtype is not None:
+            tensor.dtype = dtype
+        return tensor
     if isinstance(value, dict):
         return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()}  # type: ignore[misc,return-value]
     if isinstance(value, list):
@@ -399,6 +432,16 @@ def __init__(
         self._parent_torch_script_graph = parent_torch_script_graph
         # Domain name of the graph. None if this is the top level graph.
         self._domain_name: Optional[str] = domain_name
+        # Mapping from `torch.Value` to `TorchScriptTensor`.
+        # Because `torch.Value` does not provide an API to set and retrieve symbolic shapes,
+        # and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph,
+        # this mapping is used to keep track of the `TorchScriptTensor` associated with
+        # each `torch.Value`.
+        # `TorchScriptTensor` records dtype and symbolic shapes.
+        # This info is later serialized as `ValueInfoProto` inside ONNX, to
+        # provide shape and dtype information for nodes within nested function calls.
+        # https://github.com/onnx/onnx/issues/5487
+        self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {}

         if self._domain_name is None and self._parent_torch_script_graph is not None:
             raise RuntimeError(
@@ -441,7 +484,7 @@ def domain_name(self) -> Optional[str]:
     def add_input(
         self,
         input_name: Optional[str],
-        shape: Optional[Union[torch.Size, Sequence[Union[int, str, None]]]] = None,
+        shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> TorchScriptTensor:
         if input_name is None:
@@ -462,7 +505,11 @@ def add_input(
                 [dim if isinstance(dim, int) else None for dim in shape]  # type: ignore[union-attr]
             )
         )
-        tensor_value = _wrap_torch_value_to_tensor(torch_value)
+        tensor_value = _wrap_torch_value_to_tensor(torch_value, shape=shape, dtype=dtype)
+        if isinstance(tensor_value, TorchScriptTensor):
+            # NOTE: Only track value that maps to tensor.
+            # Value that maps to Sequence/Dict of tensors is not tracked.
+            self._value_to_tensor[torch_value] = tensor_value
         return tensor_value  # type: ignore[return-value]

     @runtime_typing.checked
@@ -486,16 +533,16 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
             self._initializers_inputs_from_parent[
                 name
             ] = self._parent_torch_script_graph.add_initializer(name, value)
-            torch_value = self._torch_graph.addInput(name)
-            torch_value.setType(torch.TensorType.create_from_tensor(value))
-            tensor_value = _wrap_torch_value_to_tensor(torch_value)
-            self._initializers_inputs[name] = tensor_value  # type: ignore[assignment]
-            return tensor_value  # type: ignore[return-value]
+        else:
+            self._initializers[name] = value

-        self._initializers[name] = value
         torch_value = self._torch_graph.addInput(name)
         torch_value.setType(torch.TensorType.create_from_tensor(value))
-        tensor_value = _wrap_torch_value_to_tensor(torch_value)
+        tensor_value = _wrap_torch_value_to_tensor(
+            torch_value, shape=value.shape, dtype=value.dtype
+        )
+        if isinstance(tensor_value, TorchScriptTensor):
+            self._value_to_tensor[torch_value] = tensor_value
         self._initializers_inputs[name] = tensor_value  # type: ignore[assignment]
         return tensor_value  # type: ignore[return-value]

@@ -595,11 +642,16 @@ def _add_torchscript_op_call(
             n_outputs=n_outputs,
         )
         assert result, "Expected at least one output from ONNX op call."
+        # NOTE: TorchScriptTensor is created here, however neither dtype nor shape is
+        # set. It is expected that the exporter will modify the returned tensor and
+        # set this info.
         if len(result) == 1:
             tensor = TorchScriptTensor(result[0])
             tensor.name = _rename_intermediate_value(tensor.name)
+            self._value_to_tensor[result[0]] = tensor
             return tensor
         tensors = tuple(TorchScriptTensor(v) for v in result)
+        self._value_to_tensor.update(dict(zip(result, tensors)))
         for tensor in tensors:
             tensor.name = _rename_intermediate_value(tensor.name)
         return tensors
@@ -634,6 +686,54 @@ def fetch_function_proto_dict(
             function_proto_dict[name_domain] = function.to_function_proto()
         return function_proto_dict

+    @runtime_typing.checked
+    def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
+        existing_value_info = {info.name: info for info in onnx_model.graph.value_info}
+
+        # Override value_info for top level graph inputs.
+        for input in self.torch_graph.inputs():
+            if input not in self._value_to_tensor:
+                raise RuntimeError(f"Input '{input.debugName()}' has no type.")
+            tensor = self._value_to_tensor[input]
+            if (value_info := tensor.value_info()) is None:
+                continue
+            for i, input_info in enumerate(onnx_model.graph.input):
+                if input_info.name == input.debugName():
+                    onnx_model.graph.input.insert(i, value_info)
+                    onnx_model.graph.input.remove(input_info)
+                    break
+
+        # Override value_info for top level graph outputs.
+        for output in self.torch_graph.outputs():
+            if output not in self._value_to_tensor:
+                raise RuntimeError(f"Output '{output.debugName()}' has no type.")
+            tensor = self._value_to_tensor[output]
+            if (value_info := tensor.value_info()) is None:
+                continue
+            for i, output_info in enumerate(onnx_model.graph.output):
+                if output_info.name == output.debugName():
+                    onnx_model.graph.output.insert(i, value_info)
+                    onnx_model.graph.output.remove(output_info)
+                    break
+
+        # Remove existing static/incomplete value info.
+        del onnx_model.graph.value_info[:]
+
+        # Insert value info for nodes within nested function calls.
+        # NOTE: This is an experimental feature, since the official ONNX spec does not
+        # allow nodes within FunctionProto to have value info. https://github.com/onnx/onnx/issues/5487
+        # The names for value info are generated uniquely to be retrievable based on
+        # the call site and call stack.
+        # The naming strategy is subject to change. Since all local functions representing
+        # nn.Modules exported by dynamo exporter have unique call sites, their function
+        # op_type name can serve to form the unique identifier for value info.
+        function_value_infos = self.generate_function_value_info_proto()
+        # Override existing value info for nodes in top level graph.
+        existing_value_info.update(function_value_infos)
+        onnx_model.graph.value_info.extend(existing_value_info.values())
+
+        return onnx_model
+
     @runtime_typing.checked
     def add_op_call(
         self,
@@ -692,6 +792,39 @@ def add_module_call(
             n_outputs=sub_torch_script_graph.num_outputs,
         )

+    @runtime_typing.checked
+    def generate_function_value_info_proto(
+        self, prefix: str = ""
+    ) -> Mapping[str, onnx.ValueInfoProto]:
+        """Unique naming strategy for value info:
+
+            {function1_op_type}/{function2_op_type}/.../{value_name}
+
+        As long as each function op_type has a unique call site, this is safe.
+
+        Preferably, the following would be better:
+
+            {node1_name}/{node2_name}/.../{value_name}
+
+        However, node name is an optional field generated on the fly during torchscript
+        graph serialization to onnx model proto. Such info is not retrievable at this point.
+        """
+        named_value_info = {}
+        for torch_value, tensor in self._value_to_tensor.items():
+            name = torch_value.debugName()
+            if (value_info := tensor.value_info()) is None:
+                continue
+            if prefix:
+                name = f"{prefix}/{name}"
+            named_value_info[name] = value_info
+        for name, sub_graph in self._sub_torch_script_graphs.items():
+            named_value_info.update(
+                sub_graph.generate_function_value_info_proto(
+                    f"{prefix}/{name}" if prefix else name
+                )
+            )
+        return named_value_info
+
     @runtime_typing.checked
     def to_function_proto(self, opset_version: int, function_name: str) -> onnx.FunctionProto:
         assert len(self.initializers) == 0, "Model local functions cannot have initializers."
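
To make the naming strategy concrete, a toy illustration (module and value names are made up) of how the recursive prefixing above composes unique value info names:

def compose(prefix: str, name: str) -> str:
    # Mirrors the f"{prefix}/{name}" if prefix else name logic in the method above.
    return f"{prefix}/{name}" if prefix else name

print(compose("", "input_1"))                    # top-level graph: "input_1"
print(compose("GraphModule", "linear_1"))        # one call deep: "GraphModule/linear_1"
print(compose("GraphModule/linear_1", "addmm"))  # nested call: "GraphModule/linear_1/addmm"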
@@ -801,6 +934,9 @@ def to_model_proto(
         onnx_model.functions.extend(function_proto_dict.values())
         onnx_model.functions.extend(_shared_functions())

+        # Override value_infos with symbolic shapes.
+        onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)
+
         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
         # them. We manually add them back and merge with existing opset_imports in the
