@@ -90,11 +90,14 @@ def _rename_intermediate_value(name: str) -> str:
90
90
class TorchScriptTensor (onnxscript_tensor .Tensor ):
91
91
"""A onnxscript tensor that wraps a torchscript Value."""
92
92
93
- def __init__ (self , value : torch .Value ):
93
+ def __init__ (
94
+ self ,
95
+ value : torch .Value ,
96
+ ):
94
97
super ().__init__ (None )
95
98
self ._torch_value : torch .Value = value
96
99
self ._concrete_value : Optional [np .ndarray ] = None
97
- self ._shape : Optional [Tuple [int | None , ...]] = None
100
+ self ._shape : Optional [Tuple [int | str | None , ...]] = None
98
101
self ._torch_dtype : Optional [torch .dtype ] = None
99
102
self ._name : Optional [str ] = None
100
103
self ._is_complex : bool = False
@@ -125,14 +128,17 @@ def name(self, name: str):
125
128
126
129
@property # type: ignore[override]
127
130
def rank (self ) -> int | None :
131
+ if self ._shape is not None :
132
+ return len (self ._shape )
133
+
128
134
value_type = self ._torch_value .type ()
129
135
if value_type is None :
130
136
return None
131
137
value_type = typing .cast (torch .TensorType , value_type )
132
138
return value_type .dim ()
133
139
134
140
@property # type: ignore[override]
135
- def shape (self ) -> Tuple [int | None , ...] | None :
141
+ def shape (self ) -> Tuple [int | str | None , ...] | None :
136
142
if self ._shape is not None :
137
143
return self ._shape
138
144
@@ -149,9 +155,17 @@ def shape(self) -> Tuple[int | None, ...] | None:
149
155
return tuple (shape )
150
156
151
157
@shape.setter
def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
    """Record ``shape`` on this tensor and mirror its static sizes onto the torch value.

    Torch symbolic dimensions (``SymInt``/``SymFloat``/``SymBool``) are stored
    as the string form of their ``node``; every other dimension is stored
    unchanged.
    """
    symbolic_types = (torch.SymInt, torch.SymFloat, torch.SymBool)

    def _normalize(dim):
        # Normalize torch symbolic dimension size to str.
        if isinstance(dim, symbolic_types):
            return str(dim.node)  # type: ignore[union-attr]
        return dim

    self._shape = tuple(map(_normalize, shape))

    # The jit api does not support assigning symbolic shapes, so only `int`
    # dimensions are forwarded; anything else is replaced by None.
    static_sizes = [dim if isinstance(dim, int) else None for dim in shape]
    current_type = self._torch_value.type()
    self._torch_value.setType(current_type.with_sizes(static_sizes))
155
169
156
170
@property # type: ignore[override]
157
171
def dtype (self ) -> torch .dtype | None :
@@ -195,6 +209,15 @@ def symbolic_value(self) -> torch.Value:
195
209
"""The symbolic Value in torch.Graph."""
196
210
return self ._torch_value
197
211
212
def value_info(self) -> Optional[onnx.ValueInfoProto]:
    """Build an ONNX ``ValueInfoProto`` from this tensor's name, dtype and shape.

    Returns:
        The value info, or None when reading the ONNX dtype raises
        ``OnnxExporterError`` or the dtype is ``TensorProto.UNDEFINED``.
    """
    try:
        elem_type = self.onnx_dtype.value
    except torch.onnx.errors.OnnxExporterError:
        return None
    else:
        if elem_type == onnx.TensorProto.UNDEFINED:
            return None
        return onnx.helper.make_tensor_value_info(self.name, elem_type, self.shape)
220
+
198
221
199
222
@runtime_typing .checked
200
223
def _unwrap_tensor_to_torch_value (
@@ -223,7 +246,12 @@ def _unwrap_tensor_to_torch_value(
223
246
224
247
@runtime_typing .checked
225
248
def _wrap_torch_value_to_tensor (
226
- value : Union [torch .Value , Mapping [str , ValidTorchValueType ], Sequence [ValidTorchValueType ]]
249
+ value : Union [
250
+ torch .Value , Mapping [str , ValidTorchValueType ], Sequence [ValidTorchValueType ]
251
+ ],
252
+ * ,
253
+ shape : Optional [Union [torch .Size , Tuple [Union [int , str , None ], ...]]] = None ,
254
+ dtype : Optional [torch .dtype ] = None ,
227
255
) -> Union [
228
256
ValidArgumentType ,
229
257
Dict [str , ValidArgumentType ],
@@ -232,7 +260,12 @@ def _wrap_torch_value_to_tensor(
232
260
]:
233
261
"""Wrap torch.Value to TorchScriptTensor."""
234
262
if isinstance (value , torch .Value ):
235
- return TorchScriptTensor (value )
263
+ tensor = TorchScriptTensor (value )
264
+ if shape is not None :
265
+ tensor .shape = shape
266
+ if dtype is not None :
267
+ tensor .dtype = dtype
268
+ return tensor
236
269
if isinstance (value , dict ):
237
270
return {k : _wrap_torch_value_to_tensor (v ) for k , v in value .items ()} # type: ignore[misc,return-value]
238
271
if isinstance (value , list ):
@@ -444,6 +477,16 @@ def __init__(
444
477
self ._parent_torch_script_graph = parent_torch_script_graph
445
478
# Domain name of the graph. None if this is the top level graph.
446
479
self ._domain_name : Optional [str ] = domain_name
480
+ # Mapping from `torch.Value` to `TorchScriptTensor`.
481
+ # Because `torch.Value` does not provide API to set and retrieve symbolic shapes,
482
+ # and because `TorchScriptTensor` is not accessible through the `torch.Graph` graph,
483
+ # this mapping is used to keep track of the `TorchScriptTensor` associated with
484
+ # `torch.Value`.
485
+ # `TorchScriptTensor` records dtype and symbolic shapes.
486
+ # This info is later serialized as `ValueInfoProto` inside ONNX, to
487
+ # provide shape and dtype information for nodes within nested function calls.
488
+ # https://github.com/onnx/onnx/issues/5487
489
+ self ._value_to_tensor : Dict [torch .Value , TorchScriptTensor ] = {}
447
490
448
491
if self ._domain_name is None and self ._parent_torch_script_graph is not None :
449
492
raise RuntimeError (
@@ -486,7 +529,7 @@ def domain_name(self) -> Optional[str]:
486
529
def add_input (
487
530
self ,
488
531
input_name : Optional [str ],
489
- shape : Optional [Union [torch .Size , Sequence [Union [int , str , None ]]]] = None ,
532
+ shape : Optional [Union [torch .Size , Tuple [Union [int , str , None ], ... ]]] = None ,
490
533
dtype : Optional [torch .dtype ] = None ,
491
534
) -> TorchScriptTensor :
492
535
if input_name is None :
@@ -507,7 +550,11 @@ def add_input(
507
550
[dim if isinstance (dim , int ) else None for dim in shape ] # type: ignore[union-attr]
508
551
)
509
552
)
510
- tensor_value = _wrap_torch_value_to_tensor (torch_value )
553
+ tensor_value = _wrap_torch_value_to_tensor (torch_value , shape = shape , dtype = dtype )
554
+ if isinstance (tensor_value , TorchScriptTensor ):
555
+ # NOTE: Only track value that maps to tensor.
556
+ # Value that maps to Sequence/Dict of tensors is not tracked.
557
+ self ._value_to_tensor [torch_value ] = tensor_value
511
558
return tensor_value # type: ignore[return-value]
512
559
513
560
@runtime_typing .checked
@@ -531,16 +578,16 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor:
531
578
self ._initializers_inputs_from_parent [
532
579
name
533
580
] = self ._parent_torch_script_graph .add_initializer (name , value )
534
- torch_value = self ._torch_graph .addInput (name )
535
- torch_value .setType (torch .TensorType .create_from_tensor (value ))
536
- tensor_value = _wrap_torch_value_to_tensor (torch_value )
537
- self ._initializers_inputs [name ] = tensor_value # type: ignore[assignment]
538
- return tensor_value # type: ignore[return-value]
581
+ else :
582
+ self ._initializers [name ] = value
539
583
540
- self ._initializers [name ] = value
541
584
torch_value = self ._torch_graph .addInput (name )
542
585
torch_value .setType (torch .TensorType .create_from_tensor (value ))
543
- tensor_value = _wrap_torch_value_to_tensor (torch_value )
586
+ tensor_value = _wrap_torch_value_to_tensor (
587
+ torch_value , shape = value .shape , dtype = value .dtype
588
+ )
589
+ if isinstance (tensor_value , TorchScriptTensor ):
590
+ self ._value_to_tensor [torch_value ] = tensor_value
544
591
self ._initializers_inputs [name ] = tensor_value # type: ignore[assignment]
545
592
return tensor_value # type: ignore[return-value]
546
593
@@ -640,11 +687,16 @@ def _add_torchscript_op_call(
640
687
n_outputs = n_outputs ,
641
688
)
642
689
assert result , "Expected at least one output from ONNX op call."
690
+ # NOTE: TorchScriptTensor is created here, however neither dtype nor shape is
691
+ # set. It is expected that exporter will modify the tensor being returned and
692
+ # set these info.
643
693
if len (result ) == 1 :
644
694
tensor = TorchScriptTensor (result [0 ])
645
695
tensor .name = _rename_intermediate_value (tensor .name )
696
+ self ._value_to_tensor [result [0 ]] = tensor
646
697
return tensor
647
698
tensors = tuple (TorchScriptTensor (v ) for v in result )
699
+ self ._value_to_tensor .update (dict (zip (result , tensors )))
648
700
for tensor in tensors :
649
701
tensor .name = _rename_intermediate_value (tensor .name )
650
702
return tensors
@@ -679,6 +731,54 @@ def fetch_function_proto_dict(
679
731
function_proto_dict [name_domain ] = function .to_function_proto ()
680
732
return function_proto_dict
681
733
734
@runtime_typing.checked
def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
    """Replace value info in ``onnx_model`` with the info tracked by this graph.

    Graph inputs and outputs whose tracked tensor yields a value info are
    swapped in place. ``graph.value_info`` is then cleared and refilled with
    the previous entries, overridden by name with the entries produced by
    ``generate_function_value_info_proto``.

    Returns:
        The same ``onnx_model`` object, modified in place.

    Raises:
        RuntimeError: If a graph input or output has no tracked tensor.
    """
    # Snapshot by name before the field is cleared further down.
    previous_value_info = {info.name: info for info in onnx_model.graph.value_info}

    def _override(kind, torch_values, proto_entries):
        # Swap each proto entry for the value info of the tensor tracked
        # under the torch value with the same debug name.
        for torch_value in torch_values:
            if torch_value not in self._value_to_tensor:
                raise RuntimeError(f"{kind} '{torch_value.debugName()}' has no type.")
            replacement = self._value_to_tensor[torch_value].value_info()
            if replacement is None:
                continue
            target_name = torch_value.debugName()
            for position, stale in enumerate(proto_entries):
                if stale.name == target_name:
                    # Insert first so the entry keeps its position, then stop
                    # iterating because the field was just mutated.
                    proto_entries.insert(position, replacement)
                    proto_entries.remove(stale)
                    break

    # Override value_info for top level graph inputs.
    _override("Input", self.torch_graph.inputs(), onnx_model.graph.input)
    # Override value_info for top level graph outputs.
    _override("Output", self.torch_graph.outputs(), onnx_model.graph.output)

    # Remove existing static/incomplete value info.
    del onnx_model.graph.value_info[:]

    # Insert value info for nodes within nested function calls.
    # NOTE: This is an experimental feature; see
    # https://github.com/onnx/onnx/issues/5487 regarding value info for nodes
    # within FunctionProto in the official ONNX spec.
    # The names for value info are generated uniquely to be retrievable based on
    # the call site and call stack.
    # The naming strategy is subject to change. Since all local functions
    # representing nn.Modules exported by dynamo exporter have unique call sites,
    # their function op_type name can serve to form the unique identifier for
    # value info.
    tracked_value_info = self.generate_function_value_info_proto()
    # Tracked entries override previous entries that share a name.
    merged_value_info = {**previous_value_info, **tracked_value_info}
    onnx_model.graph.value_info.extend(merged_value_info.values())

    return onnx_model
781
+
682
782
@runtime_typing .checked
683
783
def add_op_call (
684
784
self ,
@@ -737,6 +837,39 @@ def add_module_call(
737
837
n_outputs = sub_torch_script_graph .num_outputs ,
738
838
)
739
839
840
@runtime_typing.checked
def generate_function_value_info_proto(
    self, prefix: str = ""
) -> Mapping[str, onnx.ValueInfoProto]:
    """Collect value info for this graph and its sub graphs under unique names.

    Naming strategy:

        {function1_op_type}/{function2_op_type}/.../{value_name}

    This is safe as long as each function op_type has a unique call site.

    The following would be preferable:

        {node1_name}/{node2_name}/.../{value_name}

    However, node name is an optional field generated on the fly during
    torchscript graph serialization to onnx model proto. Such info is not
    retrievable at this point.

    Args:
        prefix: Path of enclosing function names joined by "/"; empty for the
            graph this method is first called on.

    Returns:
        Mapping from the prefixed value name to its value info. Tensors whose
        ``value_info()`` is None are left out.
    """

    def _scoped(name: str) -> str:
        # Join with the enclosing path only when there is one.
        return f"{prefix}/{name}" if prefix else name

    collected = {}
    for torch_value, tensor in self._value_to_tensor.items():
        value_name = torch_value.debugName()
        info = tensor.value_info()
        if info is not None:
            collected[_scoped(value_name)] = info

    # Recurse into sub graphs, extending the path with each sub graph's name.
    for sub_name, sub_graph in self._sub_torch_script_graphs.items():
        collected.update(sub_graph.generate_function_value_info_proto(_scoped(sub_name)))
    return collected
872
+
740
873
@runtime_typing .checked
741
874
def to_function_proto (self , opset_version : int , function_name : str ) -> onnx .FunctionProto :
742
875
assert len (self .initializers ) == 0 , "Model local functions cannot have initializers."
@@ -846,6 +979,9 @@ def to_model_proto(
846
979
onnx_model .functions .extend (function_proto_dict .values ())
847
980
onnx_model .functions .extend (_shared_functions ())
848
981
982
+ # Override value_infos with symbolic shapes.
983
+ onnx_model = self ._override_with_symbolic_value_info_proto (onnx_model )
984
+
849
985
# `_export_onnx` only exports opset_imports that is visible to it. It does not
850
986
# export opset_imports for nested functions, since it does not have access to
851
987
# them. We manually add them back and merge with existing opset_imports in the
0 commit comments