@@ -90,11 +90,14 @@ def _rename_intermediate_value(name: str) -> str:
9090class TorchScriptTensor (onnxscript_tensor .Tensor ):
9191 """A onnxscript tensor that wraps a torchscript Value."""
9292
93- def __init__ (self , value : torch .Value ):
93+ def __init__ (
94+ self ,
95+ value : torch .Value ,
96+ ):
9497 super ().__init__ (None )
9598 self ._torch_value : torch .Value = value
9699 self ._concrete_value : Optional [np .ndarray ] = None
97- self ._shape : Optional [Tuple [int | None , ...]] = None
100+ self ._shape : Optional [Tuple [int | str | None , ...]] = None
98101 self ._torch_dtype : Optional [torch .dtype ] = None
99102 self ._name : Optional [str ] = None
100103 self ._is_complex : bool = False
@@ -125,14 +128,17 @@ def name(self, name: str):
125128
126129 @property # type: ignore[override]
127130 def rank (self ) -> int | None :
131+ if self ._shape is not None :
132+ return len (self ._shape )
133+
128134 value_type = self ._torch_value .type ()
129135 if value_type is None :
130136 return None
131137 value_type = typing .cast (torch .TensorType , value_type )
132138 return value_type .dim ()
133139
134140 @property # type: ignore[override]
135- def shape (self ) -> Tuple [int | None , ...] | None :
141+ def shape (self ) -> Tuple [int | str | None , ...] | None :
136142 if self ._shape is not None :
137143 return self ._shape
138144
@@ -149,9 +155,17 @@ def shape(self) -> Tuple[int | None, ...] | None:
149155 return tuple (shape )
150156
@shape.setter
def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
    """Record ``shape`` on this tensor and mirror it onto the wrapped torch value.

    Torch symbolic dimensions (SymInt / SymFloat / SymBool) are stored as the
    string form of their node; every other entry is stored unchanged.
    """
    symbolic_types = (torch.SymInt, torch.SymFloat, torch.SymBool)
    normalized_dims = []
    for dim in shape:
        normalized_dims.append(
            str(dim.node) if isinstance(dim, symbolic_types) else dim  # type: ignore[union-attr]
        )
    self._shape = tuple(normalized_dims)

    # The jit api does not support assigning symbolic shapes, so anything that
    # is not a plain int is passed to the torch value as None.
    jit_sizes = [dim if isinstance(dim, int) else None for dim in shape]
    current_type = self._torch_value.type()
    self._torch_value.setType(current_type.with_sizes(jit_sizes))
155169
156170 @property # type: ignore[override]
157171 def dtype (self ) -> torch .dtype | None :
@@ -195,6 +209,15 @@ def symbolic_value(self) -> torch.Value:
195209 """The symbolic Value in torch.Graph."""
196210 return self ._torch_value
197211
def value_info(self) -> Optional[onnx.ValueInfoProto]:
    """Build the ONNX ``ValueInfoProto`` for this tensor from its name, dtype and shape.

    Returns:
        The value info, or None when reading the ONNX dtype raises
        ``OnnxExporterError`` or the dtype is ``TensorProto.UNDEFINED``.
    """
    try:
        elem_type = self.onnx_dtype.value
    except torch.onnx.errors.OnnxExporterError:
        return None
    else:
        if elem_type == onnx.TensorProto.UNDEFINED:
            return None
    return onnx.helper.make_tensor_value_info(self.name, elem_type, self.shape)
220+
198221
199222@runtime_typing .checked
200223def _unwrap_tensor_to_torch_value (
@@ -223,7 +246,12 @@ def _unwrap_tensor_to_torch_value(
223246
224247@runtime_typing .checked
225248def _wrap_torch_value_to_tensor (
226- value : Union [torch .Value , Mapping [str , ValidTorchValueType ], Sequence [ValidTorchValueType ]]
249+ value : Union [
250+ torch .Value , Mapping [str , ValidTorchValueType ], Sequence [ValidTorchValueType ]
251+ ],
252+ * ,
253+ shape : Optional [Union [torch .Size , Tuple [Union [int , str , None ], ...]]] = None ,
254+ dtype : Optional [torch .dtype ] = None ,
227255) -> Union [
228256 ValidArgumentType ,
229257 Dict [str , ValidArgumentType ],
@@ -232,7 +260,12 @@ def _wrap_torch_value_to_tensor(
232260]:
233261 """Wrap torch.Value to TorchScriptTensor."""
234262 if isinstance (value , torch .Value ):
235- return TorchScriptTensor (value )
263+ tensor = TorchScriptTensor (value )
264+ if shape is not None :
265+ tensor .shape = shape
266+ if dtype is not None :
267+ tensor .dtype = dtype
268+ return tensor
236269 if isinstance (value , dict ):
237270 return {k : _wrap_torch_value_to_tensor (v ) for k , v in value .items ()} # type: ignore[misc,return-value]
238271 if isinstance (value , list ):
@@ -399,6 +432,16 @@ def __init__(
399432 self ._parent_torch_script_graph = parent_torch_script_graph
400433 # Domain name of the graph. None if this is the top level graph.
401434 self ._domain_name : Optional [str ] = domain_name
435+ # Mapping from `torch.Value` to `TorchScriptTensor`.
436+ # Because `torch.Value` does not provide API to set and retrieve symbolic shapes,
437+ # and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph,
438+ # this mapping is used to keep track of the `TorchScriptTensor` associated with
439+ # `torch.Value`.
440+ # `TorchScriptTensor` records dtype and symbolic shapes.
441+ # This info is later serialized as `ValueInfoProto` inside ONNX, to
442+ # provide shape and dtype information for nodes within nested function calls.
443+ # https://github.com/onnx/onnx/issues/5487
444+ self ._value_to_tensor : Dict [torch .Value , TorchScriptTensor ] = {}
402445
403446 if self ._domain_name is None and self ._parent_torch_script_graph is not None :
404447 raise RuntimeError (
@@ -441,7 +484,7 @@ def domain_name(self) -> Optional[str]:
441484 def add_input (
442485 self ,
443486 input_name : Optional [str ],
444- shape : Optional [Union [torch .Size , Sequence [Union [int , str , None ]]]] = None ,
487+ shape : Optional [Union [torch .Size , Tuple [Union [int , str , None ], ... ]]] = None ,
445488 dtype : Optional [torch .dtype ] = None ,
446489 ) -> TorchScriptTensor :
447490 if input_name is None :
@@ -462,7 +505,11 @@ def add_input(
462505 [dim if isinstance (dim , int ) else None for dim in shape ] # type: ignore[union-attr]
463506 )
464507 )
465- tensor_value = _wrap_torch_value_to_tensor (torch_value )
508+ tensor_value = _wrap_torch_value_to_tensor (torch_value , shape = shape , dtype = dtype )
509+ if isinstance (tensor_value , TorchScriptTensor ):
510+ # NOTE: Only track value that maps to tensor.
511+ # Value that maps to Sequence/Dict of tensors is not tracked.
512+ self ._value_to_tensor [torch_value ] = tensor_value
466513 return tensor_value # type: ignore[return-value]
467514
468515 @runtime_typing .checked
@@ -486,16 +533,16 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
486533 self ._initializers_inputs_from_parent [
487534 name
488535 ] = self ._parent_torch_script_graph .add_initializer (name , value )
489- torch_value = self ._torch_graph .addInput (name )
490- torch_value .setType (torch .TensorType .create_from_tensor (value ))
491- tensor_value = _wrap_torch_value_to_tensor (torch_value )
492- self ._initializers_inputs [name ] = tensor_value # type: ignore[assignment]
493- return tensor_value # type: ignore[return-value]
536+ else :
537+ self ._initializers [name ] = value
494538
495- self ._initializers [name ] = value
496539 torch_value = self ._torch_graph .addInput (name )
497540 torch_value .setType (torch .TensorType .create_from_tensor (value ))
498- tensor_value = _wrap_torch_value_to_tensor (torch_value )
541+ tensor_value = _wrap_torch_value_to_tensor (
542+ torch_value , shape = value .shape , dtype = value .dtype
543+ )
544+ if isinstance (tensor_value , TorchScriptTensor ):
545+ self ._value_to_tensor [torch_value ] = tensor_value
499546 self ._initializers_inputs [name ] = tensor_value # type: ignore[assignment]
500547 return tensor_value # type: ignore[return-value]
501548
@@ -595,11 +642,16 @@ def _add_torchscript_op_call(
595642 n_outputs = n_outputs ,
596643 )
597644 assert result , "Expected at least one output from ONNX op call."
645+ # NOTE: TorchScriptTensor is created here, however neither dtype nor shape is
646+ # set. It is expected that exporter will modify the tensor being returned and
647+ # set these info.
598648 if len (result ) == 1 :
599649 tensor = TorchScriptTensor (result [0 ])
600650 tensor .name = _rename_intermediate_value (tensor .name )
651+ self ._value_to_tensor [result [0 ]] = tensor
601652 return tensor
602653 tensors = tuple (TorchScriptTensor (v ) for v in result )
654+ self ._value_to_tensor .update (dict (zip (result , tensors )))
603655 for tensor in tensors :
604656 tensor .name = _rename_intermediate_value (tensor .name )
605657 return tensors
@@ -634,6 +686,54 @@ def fetch_function_proto_dict(
634686 function_proto_dict [name_domain ] = function .to_function_proto ()
635687 return function_proto_dict
636688
@runtime_typing.checked
def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
    """Overwrite value info in ``onnx_model`` with the dtype/shape tracked on tensors.

    Top level graph inputs and outputs are overwritten in place (their position
    is preserved) with the value info of the ``TorchScriptTensor`` recorded for
    the matching ``torch.Value``. Tensors whose ``value_info()`` is None are left
    untouched. The graph's ``value_info`` list is then rebuilt from its previous
    entries, overridden by the entries produced by
    ``generate_function_value_info_proto``.

    Args:
        onnx_model: Model proto to update. It is modified in place.

    Returns:
        The same ``onnx_model`` object.

    Raises:
        RuntimeError: If a graph input or output has no tracked tensor.
    """
    # Snapshot the current entries before the list is cleared further down.
    existing_value_info = {info.name: info for info in onnx_model.graph.value_info}

    def override_in_place(torch_values, proto_infos, kind: str) -> None:
        """Copy tracked value info over the proto entry with the same name."""
        for torch_value in torch_values:
            tensor = self._value_to_tensor.get(torch_value)
            if tensor is None:
                raise RuntimeError(f"{kind} '{torch_value.debugName()}' has no type.")
            if (value_info := tensor.value_info()) is None:
                continue
            value_name = torch_value.debugName()
            for proto_info in proto_infos:
                if proto_info.name == value_name:
                    # Overwrite the existing entry instead of inserting and
                    # removing, so the repeated field is never resized while
                    # it is being iterated.
                    proto_info.CopyFrom(value_info)
                    break

    # Override value_info for top level graph inputs, then outputs.
    override_in_place(self.torch_graph.inputs(), onnx_model.graph.input, "Input")
    override_in_place(self.torch_graph.outputs(), onnx_model.graph.output, "Output")

    # Remove existing static/incomplete value info.
    del onnx_model.graph.value_info[:]

    # Insert value info for nodes within nested function calls.
    # NOTE: This is an experimental feature; value info for nodes inside a
    # FunctionProto is discussed in https://github.com/onnx/onnx/issues/5487.
    # The names for value info are generated uniquely to be retrievable based on
    # the call site and call stack.
    # The naming strategy is subject to change. Since all local functions representing
    # nn.Modules exported by dynamo exporter have unique call sites, their function
    # op_type name can serve to form the unique identifier for value info.
    function_value_infos = self.generate_function_value_info_proto()
    # Override existing value info for nodes in top level graph.
    existing_value_info.update(function_value_infos)
    onnx_model.graph.value_info.extend(existing_value_info.values())

    return onnx_model
736+
637737 @runtime_typing .checked
638738 def add_op_call (
639739 self ,
@@ -692,6 +792,39 @@ def add_module_call(
692792 n_outputs = sub_torch_script_graph .num_outputs ,
693793 )
694794
@runtime_typing.checked
def generate_function_value_info_proto(
    self, prefix: str = ""
) -> Mapping[str, onnx.ValueInfoProto]:
    """Collect value info for this graph and, recursively, all of its sub graphs.

    Unique naming strategies

        {function1_op_type}/{function2_op_type}/.../{value_name}

    As long as function op_type has unique call site, this is safe.

    Preferably, the following is better

        {node1_name}/{node2_name}/.../{value_name}

    However, node name is an optional field generated on the fly during torchscript
    graph serialization to onnx model proto. Such info is not retrievable at this point.

    Args:
        prefix: Path of enclosing function op_types, joined by "/". Empty for
            the graph this is first called on.

    Returns:
        Mapping from the unique name to its ``ValueInfoProto``. Tensors whose
        ``value_info()`` is None are omitted. Each proto's ``name`` field is set
        to the same unique name used as its key.
    """
    named_value_info: Dict[str, onnx.ValueInfoProto] = {}
    for torch_value, tensor in self._value_to_tensor.items():
        if (value_info := tensor.value_info()) is None:
            continue
        value_name = torch_value.debugName()
        unique_name = f"{prefix}/{value_name}" if prefix else value_name
        # The caller serializes only the protos, not the dict keys, so the
        # unique name has to be written into the proto as well.
        value_info.name = unique_name
        named_value_info[unique_name] = value_info
    for op_type, sub_graph in self._sub_torch_script_graphs.items():
        sub_prefix = f"{prefix}/{op_type}" if prefix else op_type
        named_value_info.update(sub_graph.generate_function_value_info_proto(sub_prefix))
    return named_value_info
827+
695828 @runtime_typing .checked
696829 def to_function_proto (self , opset_version : int , function_name : str ) -> onnx .FunctionProto :
697830 assert len (self .initializers ) == 0 , "Model local functions cannot have initializers."
@@ -801,6 +934,9 @@ def to_model_proto(
801934 onnx_model .functions .extend (function_proto_dict .values ())
802935 onnx_model .functions .extend (_shared_functions ())
803936
937+ # Override value_infos with symbolic shapes.
938+ onnx_model = self ._override_with_symbolic_value_info_proto (onnx_model )
939+
804940 # `_export_onnx` only exports opset_imports that is visible to it. It does not
805941 # export opset_imports for nested functions, since it does not have access to
806942 # them. We manually add them back and merge with existing opset_imports in the
0 commit comments