Skip to content

Commit ccfe90d

Browse files
committed
Update base for Update on "[Experimental] Convert CastLike to Cast when dtype is available | feat(torchlib)"
[ghstack-poisoned]
2 parents f97b719 + 77ef131 commit ccfe90d

File tree

3 files changed

+166
-19
lines changed

3 files changed

+166
-19
lines changed

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 153 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,14 @@ def _rename_intermediate_value(name: str) -> str:
9090
class TorchScriptTensor(onnxscript_tensor.Tensor):
9191
"""A onnxscript tensor that wraps a torchscript Value."""
9292

93-
def __init__(self, value: torch.Value):
93+
def __init__(
94+
self,
95+
value: torch.Value,
96+
):
9497
super().__init__(None)
9598
self._torch_value: torch.Value = value
9699
self._concrete_value: Optional[np.ndarray] = None
97-
self._shape: Optional[Tuple[int | None, ...]] = None
100+
self._shape: Optional[Tuple[int | str | None, ...]] = None
98101
self._torch_dtype: Optional[torch.dtype] = None
99102
self._name: Optional[str] = None
100103
self._is_complex: bool = False
@@ -125,14 +128,17 @@ def name(self, name: str):
125128

126129
@property # type: ignore[override]
127130
def rank(self) -> int | None:
131+
if self._shape is not None:
132+
return len(self._shape)
133+
128134
value_type = self._torch_value.type()
129135
if value_type is None:
130136
return None
131137
value_type = typing.cast(torch.TensorType, value_type)
132138
return value_type.dim()
133139

134140
@property # type: ignore[override]
135-
def shape(self) -> Tuple[int | None, ...] | None:
141+
def shape(self) -> Tuple[int | str | None, ...] | None:
136142
if self._shape is not None:
137143
return self._shape
138144

@@ -149,9 +155,17 @@ def shape(self) -> Tuple[int | None, ...] | None:
149155
return tuple(shape)
150156

151157
@shape.setter
152-
def shape(self, shape: Tuple[int | None, ...]):
153-
self._shape = shape
154-
self._torch_value.setType(self._torch_value.type().with_sizes(list(shape)))
158+
def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
159+
# Normalize torch symbolic dimension size to str.
160+
torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
161+
self._shape = tuple(
162+
str(dim.node) if isinstance(dim, torch_sym_types) else dim # type: ignore[union-attr]
163+
for dim in shape
164+
)
165+
# jit api does not support assigning symbolic shapes,
166+
# hence symbols are replaced as None.
167+
jit_shape = tuple(dim if isinstance(dim, int) else None for dim in shape)
168+
self._torch_value.setType(self._torch_value.type().with_sizes(list(jit_shape)))
155169

156170
@property # type: ignore[override]
157171
def dtype(self) -> torch.dtype | None:
@@ -195,6 +209,15 @@ def symbolic_value(self) -> torch.Value:
195209
"""The symbolic Value in torch.Graph."""
196210
return self._torch_value
197211

212+
def value_info(self) -> Optional[onnx.ValueInfoProto]:
213+
try:
214+
dtype = self.onnx_dtype.value
215+
except torch.onnx.errors.OnnxExporterError:
216+
return None
217+
if dtype == onnx.TensorProto.UNDEFINED:
218+
return None
219+
return onnx.helper.make_tensor_value_info(self.name, dtype, self.shape)
220+
198221

199222
@runtime_typing.checked
200223
def _unwrap_tensor_to_torch_value(
@@ -223,7 +246,12 @@ def _unwrap_tensor_to_torch_value(
223246

224247
@runtime_typing.checked
225248
def _wrap_torch_value_to_tensor(
226-
value: Union[torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType]]
249+
value: Union[
250+
torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType]
251+
],
252+
*,
253+
shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
254+
dtype: Optional[torch.dtype] = None,
227255
) -> Union[
228256
ValidArgumentType,
229257
Dict[str, ValidArgumentType],
@@ -232,7 +260,12 @@ def _wrap_torch_value_to_tensor(
232260
]:
233261
"""Wrap torch.Value to TorchScriptTensor."""
234262
if isinstance(value, torch.Value):
235-
return TorchScriptTensor(value)
263+
tensor = TorchScriptTensor(value)
264+
if shape is not None:
265+
tensor.shape = shape
266+
if dtype is not None:
267+
tensor.dtype = dtype
268+
return tensor
236269
if isinstance(value, dict):
237270
return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()} # type: ignore[misc,return-value]
238271
if isinstance(value, list):
@@ -444,6 +477,16 @@ def __init__(
444477
self._parent_torch_script_graph = parent_torch_script_graph
445478
# Domain name of the graph. None if this is the top level graph.
446479
self._domain_name: Optional[str] = domain_name
480+
# Mapping from `torch.Value` to `TorchScriptTensor`.
481+
# Because `torch.Value` does not provide API to set and retrieve symbolic shapes,
482+
# and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph,
483+
# this mapping is used to keep track of the `TorchScriptTensor` associated with
484+
# `torch.Value`.
485+
# `TorchScriptTensor` records dtype and symbolic shapes.
486+
# This info is later serialized as `ValueInfoProto` inside ONNX, to
487+
# provide shape and dtype information for nodes within nested function calls.
488+
# https://github.com/onnx/onnx/issues/5487
489+
self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {}
447490

448491
if self._domain_name is None and self._parent_torch_script_graph is not None:
449492
raise RuntimeError(
@@ -486,7 +529,7 @@ def domain_name(self) -> Optional[str]:
486529
def add_input(
487530
self,
488531
input_name: Optional[str],
489-
shape: Optional[Union[torch.Size, Sequence[Union[int, str, None]]]] = None,
532+
shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
490533
dtype: Optional[torch.dtype] = None,
491534
) -> TorchScriptTensor:
492535
if input_name is None:
@@ -507,7 +550,11 @@ def add_input(
507550
[dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr]
508551
)
509552
)
510-
tensor_value = _wrap_torch_value_to_tensor(torch_value)
553+
tensor_value = _wrap_torch_value_to_tensor(torch_value, shape=shape, dtype=dtype)
554+
if isinstance(tensor_value, TorchScriptTensor):
555+
# NOTE: Only track value that maps to tensor.
556+
# Value that maps to Sequence/Dict of tensors is not tracked.
557+
self._value_to_tensor[torch_value] = tensor_value
511558
return tensor_value # type: ignore[return-value]
512559

513560
@runtime_typing.checked
@@ -531,16 +578,16 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
531578
self._initializers_inputs_from_parent[
532579
name
533580
] = self._parent_torch_script_graph.add_initializer(name, value)
534-
torch_value = self._torch_graph.addInput(name)
535-
torch_value.setType(torch.TensorType.create_from_tensor(value))
536-
tensor_value = _wrap_torch_value_to_tensor(torch_value)
537-
self._initializers_inputs[name] = tensor_value # type: ignore[assignment]
538-
return tensor_value # type: ignore[return-value]
581+
else:
582+
self._initializers[name] = value
539583

540-
self._initializers[name] = value
541584
torch_value = self._torch_graph.addInput(name)
542585
torch_value.setType(torch.TensorType.create_from_tensor(value))
543-
tensor_value = _wrap_torch_value_to_tensor(torch_value)
586+
tensor_value = _wrap_torch_value_to_tensor(
587+
torch_value, shape=value.shape, dtype=value.dtype
588+
)
589+
if isinstance(tensor_value, TorchScriptTensor):
590+
self._value_to_tensor[torch_value] = tensor_value
544591
self._initializers_inputs[name] = tensor_value # type: ignore[assignment]
545592
return tensor_value # type: ignore[return-value]
546593

@@ -640,11 +687,16 @@ def _add_torchscript_op_call(
640687
n_outputs=n_outputs,
641688
)
642689
assert result, "Expected at least one output from ONNX op call."
690+
# NOTE: TorchScriptTensor is created here, however neither dtype nor shape is
691+
# set. It is expected that exporter will modify the tensor being returned and
692+
# set these info.
643693
if len(result) == 1:
644694
tensor = TorchScriptTensor(result[0])
645695
tensor.name = _rename_intermediate_value(tensor.name)
696+
self._value_to_tensor[result[0]] = tensor
646697
return tensor
647698
tensors = tuple(TorchScriptTensor(v) for v in result)
699+
self._value_to_tensor.update(dict(zip(result, tensors)))
648700
for tensor in tensors:
649701
tensor.name = _rename_intermediate_value(tensor.name)
650702
return tensors
@@ -679,6 +731,54 @@ def fetch_function_proto_dict(
679731
function_proto_dict[name_domain] = function.to_function_proto()
680732
return function_proto_dict
681733

734+
@runtime_typing.checked
735+
def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
736+
existing_value_info = {info.name: info for info in onnx_model.graph.value_info}
737+
738+
# Override value_info for top level graph inputs.
739+
for input in self.torch_graph.inputs():
740+
if input not in self._value_to_tensor:
741+
raise RuntimeError(f"Input '{input.debugName()}' has no type.")
742+
tensor = self._value_to_tensor[input]
743+
if (value_info := tensor.value_info()) is None:
744+
continue
745+
for i, input_info in enumerate(onnx_model.graph.input):
746+
if input_info.name == input.debugName():
747+
onnx_model.graph.input.insert(i, value_info)
748+
onnx_model.graph.input.remove(input_info)
749+
break
750+
751+
# Override value_info for top level graph outputs.
752+
for output in self.torch_graph.outputs():
753+
if output not in self._value_to_tensor:
754+
raise RuntimeError(f"Output '{output.debugName()}' has no type.")
755+
tensor = self._value_to_tensor[output]
756+
if (value_info := tensor.value_info()) is None:
757+
continue
758+
for i, output_info in enumerate(onnx_model.graph.output):
759+
if output_info.name == output.debugName():
760+
onnx_model.graph.output.insert(i, value_info)
761+
onnx_model.graph.output.remove(output_info)
762+
break
763+
764+
# Remove existing static/incomplete value info.
765+
del onnx_model.graph.value_info[:]
766+
767+
# Insert value info for nodes within nested function calls.
768+
# NOTE: This is an experimental feature, since in official ONNX spec, nodes
769+
# within FunctionProto to have value info. https://github.com/onnx/onnx/issues/5487
770+
# The names for value info are generated uniquely to be retrievable based on
771+
# the call site and call stack.
772+
# The naming strategy is subject to change. Since all local functions representing
773+
# nn.Modules exported by dynamo exporter have unique call sites, their function
774+
# op_type name can serve to form the unique identifier for value info.
775+
function_value_infos = self.generate_function_value_info_proto()
776+
# Override existing value info for nodes in top level graph.
777+
existing_value_info.update(function_value_infos)
778+
onnx_model.graph.value_info.extend(existing_value_info.values())
779+
780+
return onnx_model
781+
682782
@runtime_typing.checked
683783
def add_op_call(
684784
self,
@@ -737,6 +837,39 @@ def add_module_call(
737837
n_outputs=sub_torch_script_graph.num_outputs,
738838
)
739839

840+
@runtime_typing.checked
841+
def generate_function_value_info_proto(
842+
self, prefix: str = ""
843+
) -> Mapping[str, onnx.ValueInfoProto]:
844+
"""Unique naming strategies
845+
846+
{function1_op_type}/{function2_op_type}/.../{value_name}
847+
848+
As long as function op_type has unique call site, this is safe.
849+
850+
Preferably, the following is better
851+
852+
{node1_name}/{node2_name}/.../{value_name}
853+
854+
However, node name is an optional field generated on the fly during torchscript
855+
graph serialization to onnx model proto. Such info is not retrievable at this point.
856+
"""
857+
named_value_info = {}
858+
for torch_value, tensor in self._value_to_tensor.items():
859+
name = torch_value.debugName()
860+
if (value_info := tensor.value_info()) is None:
861+
continue
862+
if prefix:
863+
name = f"{prefix}/{name}"
864+
named_value_info[name] = value_info
865+
for name, sub_graph in self._sub_torch_script_graphs.items():
866+
named_value_info.update(
867+
sub_graph.generate_function_value_info_proto(
868+
f"{prefix}/{name}" if prefix else name
869+
)
870+
)
871+
return named_value_info
872+
740873
@runtime_typing.checked
741874
def to_function_proto(self, opset_version: int, function_name: str) -> onnx.FunctionProto:
742875
assert len(self.initializers) == 0, "Model local functions cannot have initializers."
@@ -846,6 +979,9 @@ def to_model_proto(
846979
onnx_model.functions.extend(function_proto_dict.values())
847980
onnx_model.functions.extend(_shared_functions())
848981

982+
# Override value_infos with symbolic shapes.
983+
onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)
984+
849985
# `_export_onnx` only exports opset_imports that is visible to it. It does not
850986
# export opset_imports for nested functions, since it does not have access to
851987
# them. We manually add them back and merge with existing opset_imports in the

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ name = "onnxscript"
77
dynamic = ["version"]
88
description = "Naturally author ONNX functions and models using a subset of Python"
99
authors = [{ name = "Microsoft Corporation", email = "[email protected]" }]
10-
urls = { "Repository" = "https://github.com/microsoft/onnxscript" }
1110
readme = "README.md"
1211
requires-python = ">=3.8"
1312
license = { file = "LICENSE" }

setup.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import datetime
88
import os
99
import pathlib
10+
import subprocess
1011

1112
import setuptools
1213

@@ -15,9 +16,20 @@
1516
VERSION_FILE = ROOT_DIR / "VERSION"
1617
version = VERSION_FILE.read_text().strip()
1718

19+
project_urls = {
20+
"Repository": "https://github.com/microsoft/onnxscript",
21+
}
1822
if os.environ.get("ONNX_SCRIPT_RELEASE") != "1":
1923
date = datetime.date.today().strftime("%Y%m%d")
2024
version = f"{version}.dev{date}"
2125

26+
commit_hash_cmd = subprocess.run(
27+
["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, check=False
28+
)
29+
if commit_hash_cmd.returncode == 0:
30+
project_urls[
31+
"Commit"
32+
] = f"https://github.com/microsoft/onnxscript/tree/{commit_hash_cmd.stdout.decode('utf-8').strip()}"
33+
2234
# NOTE: Do not include other metadata in setup.py. Put it in pyproject.toml.
23-
setuptools.setup(version=version)
35+
setuptools.setup(version=version, project_urls=project_urls, url="https://onnxscript.ai/")

0 commit comments

Comments
 (0)