Skip to content

[IR] Add support for quant_parameter_tensor_names field #2080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion onnxscript/ir/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class GraphProtocol(Protocol):
seen as a Sequence of nodes and should be used as such. For example, to obtain
all nodes as a list, call ``list(graph)``.

.. :note::
``quantization_annotation`` is deserialized into the Value's ``meta`` field
under the ``quant_parameter_tensor_names`` key. Values that are stored
under this key will be serialized as quantization annotations.

Attributes:
name: The name of the graph.
inputs: The input values of the graph.
Expand All @@ -288,7 +293,6 @@ class GraphProtocol(Protocol):
meta: Metadata store for graph transform passes.
"""

# TODO(justinchuby): Support quantization_annotation
name: str | None
inputs: MutableSequence[ValueProtocol]
outputs: MutableSequence[ValueProtocol]
Expand Down
98 changes: 89 additions & 9 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
10 # ONNX IR version where value info in functions was introduced
)
_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
_T = typing.TypeVar("_T", bound=Callable[..., Any])


Expand Down Expand Up @@ -586,6 +587,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:

Returns:
IR Graph.

.. versionadded:: 0.3
Support for *quantization_annotation* is added.
"""
return _deserialize_graph(proto, [])

Expand All @@ -606,12 +610,21 @@ def _deserialize_graph(
Returns:
IR Graph.
"""
# Process TensorAnnotation for quantization
quantization_annotations = {
annotation.tensor_name: annotation for annotation in proto.quantization_annotation
}

# Create values for initializers and inputs
initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
inputs = [_core.Input(info.name) for info in proto.input]
for info, value in zip(proto.input, inputs):
deserialize_value_info_proto(info, value)

# Add TensorAnnotation for inputs if they exist
if value.name in quantization_annotations:
_deserialize_quantization_annotation(quantization_annotations[value.name], value)

# Initialize the values dictionary for this graph scope with the inputs and initializers
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
scoped_values.append(values)
Expand All @@ -632,14 +645,21 @@ def _deserialize_graph(
type=_core.TensorType(tensor.dtype),
const_value=tensor,
)
if initializer_value.name in quantization_annotations:
_deserialize_quantization_annotation(
quantization_annotations[initializer_value.name], initializer_value
)
values[tensor.name] = initializer_value # type: ignore[index]
initializer_values.append(initializer_value)

# Add ValueInfos for this graph scope
value_info = {info.name: info for info in proto.value_info}

# Deserialize nodes with all known values
nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]
nodes = [
_deserialize_node(node, scoped_values, value_info, quantization_annotations)
for node in proto.node
]

# Fill in values for graph outputs
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
Expand All @@ -662,7 +682,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
value_info = {info.name: info for info in getattr(proto, "value_info", [])}

# TODO(justinchuby): Handle unsorted nodes
nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node]
nodes = [
_deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
for node in proto.node
]
outputs = [values[name] for name in proto.output]
graph = _core.Graph(
inputs,
Expand Down Expand Up @@ -707,6 +730,19 @@ def deserialize_value_info_proto(
return value


@_capture_errors(lambda proto, value: str(proto))
def _deserialize_quantization_annotation(
    proto: onnx.TensorAnnotation, value: _core.Value
) -> None:
    """Deserialize a quantization_annotation as TensorAnnotation into a Value.

    The key/value pairs from ``proto.quant_parameter_tensor_names`` are stored
    as a plain ``dict[str, str]`` in ``value.meta`` under the
    ``quant_parameter_tensor_names`` key, where serialization picks them up again.

    This function is marked private because we don't expect users to call it directly.

    Args:
        proto: The TensorAnnotation message for this value's tensor name.
        value: The IR value whose ``meta`` store receives the map.
    """
    value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps(
        proto.quant_parameter_tensor_names
    )


@_capture_errors(str)
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
# This logic handles when the shape is [] as well
Expand Down Expand Up @@ -844,6 +880,9 @@ def deserialize_metadata_props(
return {entry.key: entry.value for entry in proto}


_deserialize_string_string_maps = deserialize_metadata_props


def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
return _deserialize_attribute(proto, [])

Expand Down Expand Up @@ -918,14 +957,17 @@ def _deserialize_attribute(


def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
    """Deserialize a NodeProto into an IR Node, without any graph context.

    Because there is no enclosing graph scope, input values are created fresh
    (empty ``scoped_values``) and no value_info or quantization annotations
    are attached.

    Args:
        proto: The NodeProto to deserialize.

    Returns:
        The IR Node.
    """
    # NOTE(review): the scraped diff showed both the old single-line return and
    # the new call; only the new form (with quantization_annotations) is kept.
    return _deserialize_node(
        proto, scoped_values=[], value_info={}, quantization_annotations={}
    )


@_capture_errors(lambda proto, scoped_values, value_info: str(proto))
@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto))
def _deserialize_node(
proto: onnx.NodeProto,
scoped_values: list[dict[str, _core.Value]],
value_info: dict[str, onnx.ValueInfoProto],
quantization_annotations: dict[str, onnx.TensorAnnotation],
) -> _core.Node:
node_inputs: list[_core.Value | None] = []
for input_name in proto.input:
Expand Down Expand Up @@ -968,6 +1010,10 @@ def _deserialize_node(
# Fill in shape/type information if they exist
if input_name in value_info:
deserialize_value_info_proto(value_info[input_name], value)
if input_name in quantization_annotations:
_deserialize_quantization_annotation(
quantization_annotations[input_name], value
)
node_inputs.append(value)
# We can only create the value in the current scope. If the subgraph is
# referencing a value that is not in the current scope, it is impossible
Expand Down Expand Up @@ -1009,6 +1055,8 @@ def _deserialize_node(
proto.name,
proto.op_type,
)
if output_name in quantization_annotations:
_deserialize_quantization_annotation(quantization_annotations[output_name], value)
node_outputs.append(value)
return _core.Node(
proto.domain,
Expand Down Expand Up @@ -1173,6 +1221,29 @@ def _serialize_metadata_props_into(
string_string_entries.add(key=key, value=from_[key])


_serialize_string_string_maps = _serialize_metadata_props_into


def _maybe_add_quantization_annotation(
    graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol
) -> None:
    """Append a TensorAnnotation for *value* to *graph_proto*, if one is stored.

    A ``quantization_annotation`` entry is added only when the value's ``meta``
    holds a truthy map under the ``quant_parameter_tensor_names`` key; values
    without that metadata are left untouched.
    """
    quant_params = value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD)
    if not quant_params:
        return
    annotation_proto = graph_proto.quantization_annotation.add()
    _serialize_tensor_annotation_into(annotation_proto, value.name, quant_params)


def _serialize_tensor_annotation_into(
    tensor_annotation_proto: onnx.TensorAnnotation,
    tensor_name: str,
    quant_parameter_tensor_names: dict[str, str],
) -> None:
    """Fill a TensorAnnotation proto from a tensor name and its quant-parameter map.

    Args:
        tensor_annotation_proto: The (empty) TensorAnnotation message to populate.
        tensor_name: Name of the tensor the annotation describes.
        quant_parameter_tensor_names: Key/value pairs to copy into the repeated
            ``quant_parameter_tensor_names`` StringStringEntryProto field.
    """
    # Record which tensor this annotation belongs to, then copy the map entries.
    tensor_annotation_proto.tensor_name = tensor_name
    _serialize_string_string_maps(
        tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names
    )


def serialize_graph(
graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
) -> onnx.GraphProto:
Expand Down Expand Up @@ -1208,8 +1279,14 @@ def serialize_graph_into(
graph_proto.doc_string = from_.doc_string
for input_ in from_.inputs:
serialize_value_into(graph_proto.input.add(), input_)
if input_.name not in from_.initializers:
# Annotations for initializers will be added below to avoid double adding
# TODO(justinchuby): We should add a method is_initializer() on Value when
# the initializer list is tracked
_maybe_add_quantization_annotation(graph_proto, input_)
# TODO(justinchuby): Support sparse_initializer
for initializer in from_.initializers.values():
_maybe_add_quantization_annotation(graph_proto, initializer)
if initializer.const_value is None:
# Skip initializers without constant values
logger.warning(
Expand All @@ -1222,15 +1299,18 @@ def serialize_graph_into(
for node in from_:
serialize_node_into(graph_proto.node.add(), from_=node)
for node_output in node.outputs:
if not _should_create_value_info_for_value(node_output):
# No need to serialize value info if it is not set
continue
if node_output.is_graph_output():
# No need to serialize value info for these outputs because they are also graph outputs
# No need to serialize info for these outputs because they are handled as graph outputs
continue
_maybe_add_quantization_annotation(graph_proto, node_output)
if not _should_create_value_info_for_value(node_output): # pylint: disable=no-else-continue
# No need to serialize value info if it is not set
continue
serialize_value_into(graph_proto.value_info.add(), node_output)
else:
serialize_value_into(graph_proto.value_info.add(), node_output)
for output in from_.outputs:
serialize_value_into(graph_proto.output.add(), from_=output)
_maybe_add_quantization_annotation(graph_proto, output)
if from_.metadata_props:
_serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)

Expand Down
106 changes: 106 additions & 0 deletions onnxscript/ir/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import unittest

import google.protobuf.text_format
import ml_dtypes
import numpy as np
import onnx
Expand Down Expand Up @@ -290,5 +291,110 @@ def test_deserialize_graph_handles_unsorted_graph(self):
self.assertEqual(deserialized_graph[1].op_type, "Op_0")


class QuantizationAnnotationTest(unittest.TestCase):
    """Test that quantization annotations are correctly serialized and deserialized."""

    def setUp(self):
        # A two-node model (input -> TestOp1 -> intermediate_value -> TestOp2 -> output)
        # with one quantization_annotation per tensor: graph input, an
        # intermediate node output, and graph output. Each annotation carries a
        # single "custom_key" entry whose value identifies the annotated tensor.
        model_text = """\
ir_version: 8
producer_name: "pytorch"
producer_version: "2.1.1"
graph {
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
node {
input: "input"
output: "intermediate_value"
op_type: "TestOp1"
domain: "test_domain"
}
node {
input: "intermediate_value"
output: "output"
op_type: "TestOp2"
domain: "test_domain"
}
quantization_annotation {
tensor_name: "input"
quant_parameter_tensor_names {
key: "custom_key"
value: "arbitrary_value_input"
}
}
quantization_annotation {
tensor_name: "intermediate_value"
quant_parameter_tensor_names {
key: "custom_key"
value: "arbitrary_value_intermediate"
}
}
quantization_annotation {
tensor_name: "output"
quant_parameter_tensor_names {
key: "custom_key"
value: "arbitrary_value_output"
}
}
}"""
        self.model = onnx.ModelProto()
        google.protobuf.text_format.Parse(model_text, self.model)

    def test_deserialize_quantization_annotation(self):
        # Each annotation must land in the corresponding Value's meta store
        # under the "quant_parameter_tensor_names" key.
        model = serde.deserialize_model(self.model)
        self.assertEqual(
            model.graph.inputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_input"},
        )
        self.assertEqual(
            model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_intermediate"},
        )
        self.assertEqual(
            model.graph.outputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_output"},
        )

    def test_serde_roundtrip(self):
        # Annotations must survive deserialize -> serialize -> deserialize
        # unchanged for inputs, intermediate node outputs, and outputs.
        model = serde.deserialize_model(self.model)
        serialized_model = serde.serialize_model(model)
        deserialized_model = serde.deserialize_model(serialized_model)
        self.assertEqual(
            deserialized_model.graph.inputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_input"},
        )
        self.assertEqual(
            deserialized_model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_intermediate"},
        )
        self.assertEqual(
            deserialized_model.graph.outputs[0].meta["quant_parameter_tensor_names"],
            {"custom_key": "arbitrary_value_output"},
        )


# Allow running this test module directly: python serde_test.py
if __name__ == "__main__":
    unittest.main()
Loading