From 3e347fe78b4d8b0bc8a9a340bf03dced37218342 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Feb 2025 17:51:40 -0800 Subject: [PATCH 1/9] [IR] Add support for quant_parameter_tensor_names field --- onnxscript/ir/serde.py | 71 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index b333df8233..6357d64a4e 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -81,6 +81,7 @@ _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( 10 # ONNX IR version where value info in functions was introduced ) +_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names" _T = typing.TypeVar("_T", bound=Callable[..., Any]) @@ -638,8 +639,16 @@ def _deserialize_graph( # Add ValueInfos for this graph scope value_info = {info.name: info for info in proto.value_info} + # Add TensorAnnotation for quantization + quantization_annotations = { + annotation.tensor_name: annotation for annotation in proto.quantization_annotation + } + # Deserialize nodes with all known values - nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node] + nodes = [ + _deserialize_node(node, scoped_values, value_info, quantization_annotations) + for node in proto.node + ] # Fill in values for graph outputs outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output] @@ -707,6 +716,21 @@ def deserialize_value_info_proto( return value +@_capture_errors(lambda proto, value: str(proto)) +def _deserialize_quantization_annotation( + proto: onnx.TensorAnnotation, value: _core.Value +) -> _core.Value: + """Deserialize a quantization_annotation as TensorAnnotation into a Value. + + This function is marked private because we don't expect users to call it directly. + """ + assert proto.tensor_name == value.name + value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps( + proto.quant_parameter_tensor_names + ) + return value + + @_capture_errors(str) def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape: # This logic handles when the shape is [] as well @@ -844,6 +868,9 @@ def deserialize_metadata_props( return {entry.key: entry.value for entry in proto} +_deserialize_string_string_maps = deserialize_metadata_props + + def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr: return _deserialize_attribute(proto, []) @@ -918,7 +945,9 @@ def _deserialize_attribute( def deserialize_node(proto: onnx.NodeProto) -> _core.Node: - return _deserialize_node(proto, scoped_values=[], value_info={}) + return _deserialize_node( + proto, scoped_values=[], value_info={}, quantization_annotations={} + ) @_capture_errors(lambda proto, scoped_values, value_info: str(proto)) @@ -926,6 +955,7 @@ def _deserialize_node( proto: onnx.NodeProto, scoped_values: list[dict[str, _core.Value]], value_info: dict[str, onnx.ValueInfoProto], + quantization_annotations: dict[str, onnx.TensorAnnotation], ) -> _core.Node: node_inputs: list[_core.Value | None] = [] for input_name in proto.input: @@ -968,6 +998,10 @@ def _deserialize_node( # Fill in shape/type information if they exist if input_name in value_info: deserialize_value_info_proto(value_info[input_name], value) + if input_name in quantization_annotations: + _deserialize_quantization_annotation( + quantization_annotations[input_name], value + ) node_inputs.append(value) # We can only create the value in the current scope. If the subgraph is # referencing a value that is not in the current scope, it is impossible @@ -1002,6 +1036,8 @@ def _deserialize_node( # Fill in shape/type information if they exist if output_name in value_info: deserialize_value_info_proto(value_info[output_name], value) + if output_name in quantization_annotations: + _deserialize_quantization_annotation(quantization_annotations[output_name], value) else: logger.debug( "ValueInfoProto not found for output '%s' in node '%s' of type '%s'", @@ -1173,6 +1209,29 @@ def _serialize_metadata_props_into( string_string_entries.add(key=key, value=from_[key]) +_serialize_string_string_maps = _serialize_metadata_props_into + + +def _maybe_add_quantization_annotation( + graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol +) -> None: + if quantization_annotation := value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD): + _serialize_tensor_annotation_into( + graph_proto.quantization_annotation.add(), value.name, quantization_annotation + ) + + +def _serialize_tensor_annotation_into( + tensor_annotation_proto: onnx.TensorAnnotation, + tensor_name: str, + quant_parameter_tensor_names: dict[str, str], +) -> None: + tensor_annotation_proto.tensor_name = tensor_name + _serialize_string_string_maps( + tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names + ) + + def serialize_graph( graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol, ) -> onnx.GraphProto: @@ -1208,8 +1267,14 @@ def serialize_graph_into( graph_proto.doc_string = from_.doc_string for input_ in from_.inputs: serialize_value_into(graph_proto.input.add(), input_) + if input_.name not in from_.initializers: + # Annotations for initializers will be added below to avoid double adding + # TODO(justinchuby): We should add a method is_initializer() on Value when + # the initializer list is tracked + _maybe_add_quantization_annotation(graph_proto, input_) # TODO(justinchuby): Support sparse_initializer for initializer in from_.initializers.values(): + _maybe_add_quantization_annotation(graph_proto, initializer) if initializer.const_value is None: # Skip initializers without constant values logger.warning( @@ -1229,8 +1294,10 @@ def serialize_graph_into( # No need to serialize value info for these outputs because they are also graph outputs continue serialize_value_into(graph_proto.value_info.add(), node_output) + _maybe_add_quantization_annotation(graph_proto, node_output) for output in from_.outputs: serialize_value_into(graph_proto.output.add(), from_=output) + _maybe_add_quantization_annotation(graph_proto, output) if from_.metadata_props: _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props) From 60c8ae03ed334d7b3c3fd1ed9ce4df8063c1ae01 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Feb 2025 17:56:23 -0800 Subject: [PATCH 2/9] docs --- onnxscript/ir/_protocols.py | 6 +++++- onnxscript/ir/serde.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index 70ac849c90..9d038602fc 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -277,6 +277,11 @@ class GraphProtocol(Protocol): seen as a Sequence of nodes and should be used as such. For example, to obtain all nodes as a list, call ``list(graph)``. + .. :note:: + ``quantization_annotation`` is deserialized into the Value's ``meta`` field + under the ``quant_parameter_tensor_names`` key. Values that are stored + under this key will be serialized as quantization annotations. + Attributes: name: The name of the graph. inputs: The input values of the graph. @@ -288,7 +293,6 @@ class GraphProtocol(Protocol): meta: Metadata store for graph transform passes. """ - # TODO(justinchuby): Support quantization_annotation name: str | None inputs: MutableSequence[ValueProtocol] outputs: MutableSequence[ValueProtocol] diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 6357d64a4e..915559db7d 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -587,6 +587,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph: Returns: IR Graph. + + .. versionadded:: 0.3 + Support for *quantization_annotation* is added. """ return _deserialize_graph(proto, []) From 8eab3bf1aa9af88db458bd50facb14d1e04df6b3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Feb 2025 18:05:27 -0800 Subject: [PATCH 3/9] function --- onnxscript/ir/serde.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 915559db7d..0cc64bc516 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -674,7 +674,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: value_info = {info.name: info for info in getattr(proto, "value_info", [])} # TODO(justinchuby): Handle unsorted nodes - nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node] + nodes = [ + _deserialize_node(node, [values], value_info=value_info, quantization_annotations={}) + for node in proto.node + ] outputs = [values[name] for name in proto.output] graph = _core.Graph( inputs, From df0354eaba89756da89410828c89c7aa599c5c4f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Feb 2025 18:06:10 -0800 Subject: [PATCH 4/9] error notes --- onnxscript/ir/serde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 0cc64bc516..4a5502a477 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -956,7 +956,7 @@ def deserialize_node(proto: onnx.NodeProto) -> _core.Node: ) -@_capture_errors(lambda proto, scoped_values, value_info: str(proto)) +@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto)) def _deserialize_node( proto: onnx.NodeProto, scoped_values: list[dict[str, _core.Value]], From 97aeb1929148b6e67b2b06d42333febc51fdeaa5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 27 Feb 2025 18:08:27 -0800 Subject: [PATCH 5/9] order --- onnxscript/ir/serde.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 4a5502a477..2da9d8c38d 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -1042,8 +1042,6 @@ def _deserialize_node( # Fill in shape/type information if they exist if output_name in value_info: deserialize_value_info_proto(value_info[output_name], value) - if output_name in quantization_annotations: - _deserialize_quantization_annotation(quantization_annotations[output_name], value) else: logger.debug( "ValueInfoProto not found for output '%s' in node '%s' of type '%s'", @@ -1051,6 +1049,8 @@ def _deserialize_node( proto.name, proto.op_type, ) + if output_name in quantization_annotations: + _deserialize_quantization_annotation(quantization_annotations[output_name], value) node_outputs.append(value) return _core.Node( proto.domain, From bca7eed80b34b0aeeede275073420825bc23cbe0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 3 Mar 2025 10:57:59 -0800 Subject: [PATCH 6/9] test --- onnxscript/ir/serde.py | 29 ++++++---- onnxscript/ir/serde_test.py | 106 ++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 10 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 2da9d8c38d..69628098cb 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -610,12 +610,21 @@ def _deserialize_graph( Returns: IR Graph. """ + # Process TensorAnnotation for quantization + quantization_annotations = { + annotation.tensor_name: annotation for annotation in proto.quantization_annotation + } + # Create values for initializers and inputs initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer] inputs = [_core.Input(info.name) for info in proto.input] for info, value in zip(proto.input, inputs): deserialize_value_info_proto(info, value) + # Add TensorAnnotation for inputs if they exist + if value.name in quantization_annotations: + _deserialize_quantization_annotation(quantization_annotations[value.name], value) + # Initialize the values dictionary for this graph scope with the inputs and initializers values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] scoped_values.append(values) @@ -636,17 +645,16 @@ def _deserialize_graph( type=_core.TensorType(tensor.dtype), const_value=tensor, ) + if initializer_value.name in quantization_annotations: + _deserialize_quantization_annotation( + quantization_annotations[initializer_value.name], initializer_value + ) values[tensor.name] = initializer_value # type: ignore[index] initializer_values.append(initializer_value) # Add ValueInfos for this graph scope value_info = {info.name: info for info in proto.value_info} - # Add TensorAnnotation for quantization - quantization_annotations = { - annotation.tensor_name: annotation for annotation in proto.quantization_annotation - } - # Deserialize nodes with all known values nodes = [ _deserialize_node(node, scoped_values, value_info, quantization_annotations) @@ -1293,14 +1301,15 @@ def serialize_graph_into( for node in from_: serialize_node_into(graph_proto.node.add(), from_=node) for node_output in node.outputs: - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue if node_output.is_graph_output(): - # No need to serialize value info for these outputs because they are also graph outputs + # No need to serialize info for these outputs because they are handled as graph outputs continue - serialize_value_into(graph_proto.value_info.add(), node_output) _maybe_add_quantization_annotation(graph_proto, node_output) + if not _should_create_value_info_for_value(node_output): + # No need to serialize value info if it is not set + continue + else: + serialize_value_into(graph_proto.value_info.add(), node_output) for output in from_.outputs: serialize_value_into(graph_proto.output.add(), from_=output) _maybe_add_quantization_annotation(graph_proto, output) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py index f46756055e..b4d13ebdea 100644 --- a/onnxscript/ir/serde_test.py +++ b/onnxscript/ir/serde_test.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import unittest +import google.protobuf.text_format import ml_dtypes import numpy as np import onnx @@ -290,5 +291,110 @@ def test_deserialize_graph_handles_unsorted_graph(self): self.assertEqual(deserialized_graph[1].op_type, "Op_0") +class QuantizationAnnotationTest(unittest.TestCase): + """Test that quantization annotations are correctly serialized and deserialized.""" + + def setUp(self): + model_text = """\ +ir_version: 8 +producer_name: "pytorch" +producer_version: "2.1.1" +graph { + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } + node { + input: "input" + output: "intermediate_value" + op_type: "TestOp1" + domain: "test_domain" + } + node { + input: "intermediate_value" + output: "output" + op_type: "TestOp2" + domain: "test_domain" + } + quantization_annotation { + tensor_name: "input" + quant_parameter_tensor_names { + key: "custom_key" + value: "arbitrary_value_input" + } + } + quantization_annotation { + tensor_name: "intermediate_value" + quant_parameter_tensor_names { + key: "custom_key" + value: "arbitrary_value_intermediate" + } + } + quantization_annotation { + tensor_name: "output" + quant_parameter_tensor_names { + key: "custom_key" + value: "arbitrary_value_output" + } + } +}""" + self.model = onnx.ModelProto() + google.protobuf.text_format.Parse(model_text, self.model) + + def test_deserialize_quantization_annotation(self): + model = serde.deserialize_model(self.model) + self.assertEqual( + model.graph.inputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_input"}, + ) + self.assertEqual( + model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_intermediate"}, + ) + self.assertEqual( + model.graph.outputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_output"}, + ) + + def test_serde_roundtrip(self): + model = serde.deserialize_model(self.model) + serialized_model = serde.serialize_model(model) + deserialized_model = serde.deserialize_model(serialized_model) + self.assertEqual( + deserialized_model.graph.inputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_input"}, + ) + self.assertEqual( + deserialized_model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_intermediate"}, + ) + self.assertEqual( + deserialized_model.graph.outputs[0].meta["quant_parameter_tensor_names"], + {"custom_key": "arbitrary_value_output"}, + ) + + if __name__ == "__main__": unittest.main() From 791f71b8a88b66fd23a3c71ed1d42b21db386886 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 3 Mar 2025 11:00:24 -0800 Subject: [PATCH 7/9] assertion --- onnxscript/ir/serde.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 69628098cb..c4e5d02b49 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -738,7 +738,6 @@ def _deserialize_quantization_annotation( This function is marked private because we don't expect users to call it directly. """ - assert proto.tensor_name == value.name value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps( proto.quant_parameter_tensor_names ) From e45ae8dc5c5592321acf8b69ccff8bdc123ecdd0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 3 Mar 2025 11:00:45 -0800 Subject: [PATCH 8/9] return --- onnxscript/ir/serde.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index c4e5d02b49..b9ed646422 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -733,7 +733,7 @@ def deserialize_value_info_proto( @_capture_errors(lambda proto, value: str(proto)) def _deserialize_quantization_annotation( proto: onnx.TensorAnnotation, value: _core.Value -) -> _core.Value: +) -> None: """Deserialize a quantization_annotation as TensorAnnotation into a Value. This function is marked private because we don't expect users to call it directly. @@ -741,7 +741,6 @@ def _deserialize_quantization_annotation( value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps( proto.quant_parameter_tensor_names ) - return value @_capture_errors(str) From f38c45f1f032e6082235b77d1787f2faeb26f9b4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 5 Mar 2025 10:41:36 -0800 Subject: [PATCH 9/9] lint --- onnxscript/ir/serde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index b9ed646422..4988562030 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -1303,7 +1303,7 @@ def serialize_graph_into( # No need to serialize info for these outputs because they are handled as graph outputs continue _maybe_add_quantization_annotation(graph_proto, node_output) - if not _should_create_value_info_for_value(node_output): + if not _should_create_value_info_for_value(node_output): # pylint: disable=no-else-continue # No need to serialize value info if it is not set continue else: