81
81
_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
82
82
10 # ONNX IR version where value info in functions was introduced
83
83
)
84
+ _QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
84
85
_T = typing .TypeVar ("_T" , bound = Callable [..., Any ])
85
86
86
87
@@ -586,6 +587,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
586
587
587
588
Returns:
588
589
IR Graph.
590
+
591
+ .. versionadded:: 0.3
592
+ Support for *quantization_annotation* is added.
589
593
"""
590
594
return _deserialize_graph (proto , [])
591
595
@@ -606,12 +610,21 @@ def _deserialize_graph(
606
610
Returns:
607
611
IR Graph.
608
612
"""
613
+ # Process TensorAnnotation for quantization
614
+ quantization_annotations = {
615
+ annotation .tensor_name : annotation for annotation in proto .quantization_annotation
616
+ }
617
+
609
618
# Create values for initializers and inputs
610
619
initializer_tensors = [deserialize_tensor (tensor ) for tensor in proto .initializer ]
611
620
inputs = [_core .Input (info .name ) for info in proto .input ]
612
621
for info , value in zip (proto .input , inputs ):
613
622
deserialize_value_info_proto (info , value )
614
623
624
+ # Add TensorAnnotation for inputs if they exist
625
+ if value .name in quantization_annotations :
626
+ _deserialize_quantization_annotation (quantization_annotations [value .name ], value )
627
+
615
628
# Initialize the values dictionary for this graph scope with the inputs and initializers
616
629
values : dict [str , _core .Value ] = {v .name : v for v in inputs } # type: ignore[misc]
617
630
scoped_values .append (values )
@@ -632,14 +645,21 @@ def _deserialize_graph(
632
645
type = _core .TensorType (tensor .dtype ),
633
646
const_value = tensor ,
634
647
)
648
+ if initializer_value .name in quantization_annotations :
649
+ _deserialize_quantization_annotation (
650
+ quantization_annotations [initializer_value .name ], initializer_value
651
+ )
635
652
values [tensor .name ] = initializer_value # type: ignore[index]
636
653
initializer_values .append (initializer_value )
637
654
638
655
# Add ValueInfos for this graph scope
639
656
value_info = {info .name : info for info in proto .value_info }
640
657
641
658
# Deserialize nodes with all known values
642
- nodes = [_deserialize_node (node , scoped_values , value_info ) for node in proto .node ]
659
+ nodes = [
660
+ _deserialize_node (node , scoped_values , value_info , quantization_annotations )
661
+ for node in proto .node
662
+ ]
643
663
644
664
# Fill in values for graph outputs
645
665
outputs = [deserialize_value_info_proto (info , values [info .name ]) for info in proto .output ]
@@ -662,7 +682,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
662
682
value_info = {info .name : info for info in getattr (proto , "value_info" , [])}
663
683
664
684
# TODO(justinchuby): Handle unsorted nodes
665
- nodes = [_deserialize_node (node , [values ], value_info = value_info ) for node in proto .node ]
685
+ nodes = [
686
+ _deserialize_node (node , [values ], value_info = value_info , quantization_annotations = {})
687
+ for node in proto .node
688
+ ]
666
689
outputs = [values [name ] for name in proto .output ]
667
690
graph = _core .Graph (
668
691
inputs ,
@@ -707,6 +730,19 @@ def deserialize_value_info_proto(
707
730
return value
708
731
709
732
733
+ @_capture_errors (lambda proto , value : str (proto ))
734
+ def _deserialize_quantization_annotation (
735
+ proto : onnx .TensorAnnotation , value : _core .Value
736
+ ) -> None :
737
+ """Deserialize a quantization_annotation as TensorAnnotation into a Value.
738
+
739
+ This function is marked private because we don't expect users to call it directly.
740
+ """
741
+ value .meta [_QUANT_PARAMETER_TENSOR_NAMES_FIELD ] = _deserialize_string_string_maps (
742
+ proto .quant_parameter_tensor_names
743
+ )
744
+
745
+
710
746
@_capture_errors (str )
711
747
def deserialize_tensor_shape (proto : onnx .TensorShapeProto ) -> _core .Shape :
712
748
# This logic handles when the shape is [] as well
@@ -844,6 +880,9 @@ def deserialize_metadata_props(
844
880
return {entry .key : entry .value for entry in proto }
845
881
846
882
883
+ _deserialize_string_string_maps = deserialize_metadata_props
884
+
885
+
847
886
def deserialize_attribute (proto : onnx .AttributeProto ) -> _core .Attr | _core .RefAttr :
848
887
return _deserialize_attribute (proto , [])
849
888
@@ -918,14 +957,17 @@ def _deserialize_attribute(
918
957
919
958
920
959
def deserialize_node (proto : onnx .NodeProto ) -> _core .Node :
921
- return _deserialize_node (proto , scoped_values = [], value_info = {})
960
+ return _deserialize_node (
961
+ proto , scoped_values = [], value_info = {}, quantization_annotations = {}
962
+ )
922
963
923
964
924
- @_capture_errors (lambda proto , scoped_values , value_info : str (proto ))
965
+ @_capture_errors (lambda proto , scoped_values , value_info , quantization_annotations : str (proto ))
925
966
def _deserialize_node (
926
967
proto : onnx .NodeProto ,
927
968
scoped_values : list [dict [str , _core .Value ]],
928
969
value_info : dict [str , onnx .ValueInfoProto ],
970
+ quantization_annotations : dict [str , onnx .TensorAnnotation ],
929
971
) -> _core .Node :
930
972
node_inputs : list [_core .Value | None ] = []
931
973
for input_name in proto .input :
@@ -968,6 +1010,10 @@ def _deserialize_node(
968
1010
# Fill in shape/type information if they exist
969
1011
if input_name in value_info :
970
1012
deserialize_value_info_proto (value_info [input_name ], value )
1013
+ if input_name in quantization_annotations :
1014
+ _deserialize_quantization_annotation (
1015
+ quantization_annotations [input_name ], value
1016
+ )
971
1017
node_inputs .append (value )
972
1018
# We can only create the value in the current scope. If the subgraph is
973
1019
# referencing a value that is not in the current scope, it is impossible
@@ -1009,6 +1055,8 @@ def _deserialize_node(
1009
1055
proto .name ,
1010
1056
proto .op_type ,
1011
1057
)
1058
+ if output_name in quantization_annotations :
1059
+ _deserialize_quantization_annotation (quantization_annotations [output_name ], value )
1012
1060
node_outputs .append (value )
1013
1061
return _core .Node (
1014
1062
proto .domain ,
@@ -1173,6 +1221,29 @@ def _serialize_metadata_props_into(
1173
1221
string_string_entries .add (key = key , value = from_ [key ])
1174
1222
1175
1223
1224
+ _serialize_string_string_maps = _serialize_metadata_props_into
1225
+
1226
+
1227
+ def _maybe_add_quantization_annotation (
1228
+ graph_proto : onnx .GraphProto , value : _protocols .ValueProtocol
1229
+ ) -> None :
1230
+ if quantization_annotation := value .meta .get (_QUANT_PARAMETER_TENSOR_NAMES_FIELD ):
1231
+ _serialize_tensor_annotation_into (
1232
+ graph_proto .quantization_annotation .add (), value .name , quantization_annotation
1233
+ )
1234
+
1235
+
1236
+ def _serialize_tensor_annotation_into (
1237
+ tensor_annotation_proto : onnx .TensorAnnotation ,
1238
+ tensor_name : str ,
1239
+ quant_parameter_tensor_names : dict [str , str ],
1240
+ ) -> None :
1241
+ tensor_annotation_proto .tensor_name = tensor_name
1242
+ _serialize_string_string_maps (
1243
+ tensor_annotation_proto .quant_parameter_tensor_names , quant_parameter_tensor_names
1244
+ )
1245
+
1246
+
1176
1247
def serialize_graph (
1177
1248
graph : _protocols .GraphProtocol | _protocols .GraphViewProtocol ,
1178
1249
) -> onnx .GraphProto :
@@ -1208,8 +1279,14 @@ def serialize_graph_into(
1208
1279
graph_proto .doc_string = from_ .doc_string
1209
1280
for input_ in from_ .inputs :
1210
1281
serialize_value_into (graph_proto .input .add (), input_ )
1282
+ if input_ .name not in from_ .initializers :
1283
+ # Annotations for initializers will be added below to avoid double adding
1284
+ # TODO(justinchuby): We should add a method is_initializer() on Value when
1285
+ # the initializer list is tracked
1286
+ _maybe_add_quantization_annotation (graph_proto , input_ )
1211
1287
# TODO(justinchuby): Support sparse_initializer
1212
1288
for initializer in from_ .initializers .values ():
1289
+ _maybe_add_quantization_annotation (graph_proto , initializer )
1213
1290
if initializer .const_value is None :
1214
1291
# Skip initializers without constant values
1215
1292
logger .warning (
@@ -1222,15 +1299,18 @@ def serialize_graph_into(
1222
1299
for node in from_ :
1223
1300
serialize_node_into (graph_proto .node .add (), from_ = node )
1224
1301
for node_output in node .outputs :
1225
- if not _should_create_value_info_for_value (node_output ):
1226
- # No need to serialize value info if it is not set
1227
- continue
1228
1302
if node_output .is_graph_output ():
1229
- # No need to serialize value info for these outputs because they are also graph outputs
1303
+ # No need to serialize info for these outputs because they are handled as graph outputs
1304
+ continue
1305
+ _maybe_add_quantization_annotation (graph_proto , node_output )
1306
+ if not _should_create_value_info_for_value (node_output ): # pylint: disable=no-else-continue
1307
+ # No need to serialize value info if it is not set
1230
1308
continue
1231
- serialize_value_into (graph_proto .value_info .add (), node_output )
1309
+ else :
1310
+ serialize_value_into (graph_proto .value_info .add (), node_output )
1232
1311
for output in from_ .outputs :
1233
1312
serialize_value_into (graph_proto .output .add (), from_ = output )
1313
+ _maybe_add_quantization_annotation (graph_proto , output )
1234
1314
if from_ .metadata_props :
1235
1315
_serialize_metadata_props_into (graph_proto .metadata_props , from_ .metadata_props )
1236
1316
0 commit comments