Commit ddce766

[IR] Add support for quant_parameter_tensor_names field (#2080)
Support `quantization_annotation` for graph inputs, node inputs/outputs, and graph outputs.

Two design decisions were made:

1. Make `ir.Value` carry the `quant_parameter_tensor_names` information. This is similar to `ValueInfoProto`: in the proto we store a list of messages keyed by tensor name, but the information really belongs to the individual values.
2. `quantization_annotation` is deserialized into the Value's `meta` field under the `quant_parameter_tensor_names` key. Values stored under this key are serialized back as quantization annotations. I chose to add an entry in `meta` instead of creating a new property on `Value` to avoid overcomplicating the properties of `Value`.

## Example usage

```python
>>> from onnxscript import ir
>>> model = ir.load("l_1_n_12_z_384_i_1536.onnx")
>>> model.graph.node("MVAU_rtl_0").outputs[0]
Value('MVAU_rtl_0_out0', type=Tensor(FLOAT), shape=[1,128,384], producer=MVAU_rtl_0, index=0)
>>> model.graph.node("MVAU_rtl_0").outputs[0].meta
MetadataStore({'quant_parameter_tensor_names': {'finn_datatype': 'INT22'}}, invalid_keys=set())
>>> ir.save(model, "model_with_quant_params.textproto")
```
1 parent 8e53070 commit ddce766
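
Not part of the commit: a minimal sketch of how the new `meta` key can be used from user code to both read and write annotations. It assumes a local `model.onnx` file exists, and the `"SCALE_TENSOR": "scale_0"` pair is an arbitrary illustrative mapping, not a value taken from the commit.

```python
from onnxscript import ir

# Load a model (hypothetical file name) and pick a value to annotate.
model = ir.load("model.onnx")
value = model.graph.node(0).outputs[0]

# Writing a dict under the "quant_parameter_tensor_names" meta key makes the
# serializer emit a matching quantization_annotation entry for this value.
value.meta["quant_parameter_tensor_names"] = {"SCALE_TENSOR": "scale_0"}

# Reading it back is a plain dict lookup on the metadata store.
print(value.meta.get("quant_parameter_tensor_names"))

ir.save(model, "model_with_quant_params.onnx")
```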

File tree

3 files changed: +200 −10 lines changed


onnxscript/ir/_protocols.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -277,6 +277,11 @@ class GraphProtocol(Protocol):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. :note::
+        ``quantization_annotation`` is deserialized into the Value's ``meta`` field
+        under the ``quant_parameter_tensor_names`` key. Values that are stored
+        under this key will be serialized as quantization annotations.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -288,7 +293,6 @@ class GraphProtocol(Protocol):
         meta: Metadata store for graph transform passes.
     """
 
-    # TODO(justinchuby): Support quantization_annotation
     name: str | None
     inputs: MutableSequence[ValueProtocol]
     outputs: MutableSequence[ValueProtocol]
```

onnxscript/ir/serde.py

Lines changed: 89 additions & 9 deletions
```diff
@@ -81,6 +81,7 @@
 _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
     10  # ONNX IR version where value info in functions was introduced
 )
+_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
 _T = typing.TypeVar("_T", bound=Callable[..., Any])
 
 
@@ -586,6 +587,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
 
     Returns:
         IR Graph.
+
+    .. versionadded:: 0.3
+        Support for *quantization_annotation* is added.
     """
     return _deserialize_graph(proto, [])
 
@@ -606,12 +610,21 @@ def _deserialize_graph(
     Returns:
         IR Graph.
     """
+    # Process TensorAnnotation for quantization
+    quantization_annotations = {
+        annotation.tensor_name: annotation for annotation in proto.quantization_annotation
+    }
+
     # Create values for initializers and inputs
     initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
     inputs = [_core.Input(info.name) for info in proto.input]
     for info, value in zip(proto.input, inputs):
         deserialize_value_info_proto(info, value)
 
+        # Add TensorAnnotation for inputs if they exist
+        if value.name in quantization_annotations:
+            _deserialize_quantization_annotation(quantization_annotations[value.name], value)
+
     # Initialize the values dictionary for this graph scope with the inputs and initializers
     values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
     scoped_values.append(values)
@@ -632,14 +645,21 @@
             type=_core.TensorType(tensor.dtype),
             const_value=tensor,
         )
+        if initializer_value.name in quantization_annotations:
+            _deserialize_quantization_annotation(
+                quantization_annotations[initializer_value.name], initializer_value
+            )
         values[tensor.name] = initializer_value  # type: ignore[index]
         initializer_values.append(initializer_value)
 
     # Add ValueInfos for this graph scope
     value_info = {info.name: info for info in proto.value_info}
 
     # Deserialize nodes with all known values
-    nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]
+    nodes = [
+        _deserialize_node(node, scoped_values, value_info, quantization_annotations)
+        for node in proto.node
+    ]
 
     # Fill in values for graph outputs
     outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
@@ -662,7 +682,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
     value_info = {info.name: info for info in getattr(proto, "value_info", [])}
 
     # TODO(justinchuby): Handle unsorted nodes
-    nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node]
+    nodes = [
+        _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
+        for node in proto.node
+    ]
     outputs = [values[name] for name in proto.output]
     graph = _core.Graph(
         inputs,
@@ -707,6 +730,19 @@ def deserialize_value_info_proto(
     return value
 
 
+@_capture_errors(lambda proto, value: str(proto))
+def _deserialize_quantization_annotation(
+    proto: onnx.TensorAnnotation, value: _core.Value
+) -> None:
+    """Deserialize a quantization_annotation as TensorAnnotation into a Value.
+
+    This function is marked private because we don't expect users to call it directly.
+    """
+    value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps(
+        proto.quant_parameter_tensor_names
+    )
+
+
 @_capture_errors(str)
 def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
     # This logic handles when the shape is [] as well
@@ -844,6 +880,9 @@ def deserialize_metadata_props(
     return {entry.key: entry.value for entry in proto}
 
 
+_deserialize_string_string_maps = deserialize_metadata_props
+
+
 def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
     return _deserialize_attribute(proto, [])
 
@@ -918,14 +957,17 @@ def _deserialize_attribute(
 
 
 def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
-    return _deserialize_node(proto, scoped_values=[], value_info={})
+    return _deserialize_node(
+        proto, scoped_values=[], value_info={}, quantization_annotations={}
+    )
 
 
-@_capture_errors(lambda proto, scoped_values, value_info: str(proto))
+@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto))
 def _deserialize_node(
     proto: onnx.NodeProto,
     scoped_values: list[dict[str, _core.Value]],
     value_info: dict[str, onnx.ValueInfoProto],
+    quantization_annotations: dict[str, onnx.TensorAnnotation],
 ) -> _core.Node:
     node_inputs: list[_core.Value | None] = []
     for input_name in proto.input:
@@ -968,6 +1010,10 @@ def _deserialize_node(
             # Fill in shape/type information if they exist
             if input_name in value_info:
                 deserialize_value_info_proto(value_info[input_name], value)
+            if input_name in quantization_annotations:
+                _deserialize_quantization_annotation(
+                    quantization_annotations[input_name], value
+                )
             node_inputs.append(value)
             # We can only create the value in the current scope. If the subgraph is
             # referencing a value that is not in the current scope, it is impossible
@@ -1009,6 +1055,8 @@ def _deserialize_node(
                 proto.name,
                 proto.op_type,
             )
+        if output_name in quantization_annotations:
+            _deserialize_quantization_annotation(quantization_annotations[output_name], value)
         node_outputs.append(value)
     return _core.Node(
         proto.domain,
@@ -1173,6 +1221,29 @@ def _serialize_metadata_props_into(
         string_string_entries.add(key=key, value=from_[key])
 
 
+_serialize_string_string_maps = _serialize_metadata_props_into
+
+
+def _maybe_add_quantization_annotation(
+    graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol
+) -> None:
+    if quantization_annotation := value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD):
+        _serialize_tensor_annotation_into(
+            graph_proto.quantization_annotation.add(), value.name, quantization_annotation
+        )
+
+
+def _serialize_tensor_annotation_into(
+    tensor_annotation_proto: onnx.TensorAnnotation,
+    tensor_name: str,
+    quant_parameter_tensor_names: dict[str, str],
+) -> None:
+    tensor_annotation_proto.tensor_name = tensor_name
+    _serialize_string_string_maps(
+        tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names
+    )
+
+
 def serialize_graph(
     graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
 ) -> onnx.GraphProto:
@@ -1208,8 +1279,14 @@ def serialize_graph_into(
         graph_proto.doc_string = from_.doc_string
     for input_ in from_.inputs:
         serialize_value_into(graph_proto.input.add(), input_)
+        if input_.name not in from_.initializers:
+            # Annotations for initializers will be added below to avoid double adding
+            # TODO(justinchuby): We should add a method is_initializer() on Value when
+            # the initializer list is tracked
+            _maybe_add_quantization_annotation(graph_proto, input_)
     # TODO(justinchuby): Support sparse_initializer
     for initializer in from_.initializers.values():
+        _maybe_add_quantization_annotation(graph_proto, initializer)
         if initializer.const_value is None:
             # Skip initializers without constant values
             logger.warning(
@@ -1222,15 +1299,18 @@ def serialize_graph_into(
     for node in from_:
         serialize_node_into(graph_proto.node.add(), from_=node)
         for node_output in node.outputs:
-            if not _should_create_value_info_for_value(node_output):
-                # No need to serialize value info if it is not set
-                continue
             if node_output.is_graph_output():
-                # No need to serialize value info for these outputs because they are also graph outputs
+                # No need to serialize info for these outputs because they are handled as graph outputs
+                continue
+            _maybe_add_quantization_annotation(graph_proto, node_output)
+            if not _should_create_value_info_for_value(node_output):  # pylint: disable=no-else-continue
+                # No need to serialize value info if it is not set
                 continue
-            serialize_value_into(graph_proto.value_info.add(), node_output)
+            else:
+                serialize_value_into(graph_proto.value_info.add(), node_output)
     for output in from_.outputs:
         serialize_value_into(graph_proto.output.add(), from_=output)
+        _maybe_add_quantization_annotation(graph_proto, output)
     if from_.metadata_props:
         _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
 
```
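To make the serialization path above concrete, here is a minimal sketch (not from the commit) that round-trips a small graph and checks that a `quant_parameter_tensor_names` entry placed in a value's `meta` comes out as a `quantization_annotation` in the resulting `GraphProto`. The tiny Identity model, the tensor name `x`, and the `custom_key`/`some_value` pair are invented for illustration.

```python
import onnx.parser

from onnxscript.ir import serde

# A minimal model: y = Identity(x), built with the ONNX textual syntax.
model_proto = onnx.parser.parse_model(
    """
    <ir_version: 8, opset_import: ["" : 18]>
    agraph (float[1] x) => (float[1] y) {
        y = Identity(x)
    }
    """
)

# Deserialize into the IR and attach quantization parameters to the graph input.
model = serde.deserialize_model(model_proto)
model.graph.inputs[0].meta["quant_parameter_tensor_names"] = {"custom_key": "some_value"}

# Serializing back turns the meta entry into a quantization_annotation proto.
roundtripped = serde.serialize_model(model)
annotation = roundtripped.graph.quantization_annotation[0]
print(annotation.tensor_name)  # -> "x"
print({kv.key: kv.value for kv in annotation.quant_parameter_tensor_names})
# -> {'custom_key': 'some_value'}
```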
onnxscript/ir/serde_test.py

Lines changed: 106 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 import unittest
 
+import google.protobuf.text_format
 import ml_dtypes
 import numpy as np
 import onnx
@@ -290,5 +291,110 @@ def test_deserialize_graph_handles_unsorted_graph(self):
         self.assertEqual(deserialized_graph[1].op_type, "Op_0")
 
 
+class QuantizationAnnotationTest(unittest.TestCase):
+    """Test that quantization annotations are correctly serialized and deserialized."""
+
+    def setUp(self):
+        model_text = """\
+ir_version: 8
+producer_name: "pytorch"
+producer_version: "2.1.1"
+graph {
+  input {
+    name: "input"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "output"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+        }
+      }
+    }
+  }
+  node {
+    input: "input"
+    output: "intermediate_value"
+    op_type: "TestOp1"
+    domain: "test_domain"
+  }
+  node {
+    input: "intermediate_value"
+    output: "output"
+    op_type: "TestOp2"
+    domain: "test_domain"
+  }
+  quantization_annotation {
+    tensor_name: "input"
+    quant_parameter_tensor_names {
+      key: "custom_key"
+      value: "arbitrary_value_input"
+    }
+  }
+  quantization_annotation {
+    tensor_name: "intermediate_value"
+    quant_parameter_tensor_names {
+      key: "custom_key"
+      value: "arbitrary_value_intermediate"
+    }
+  }
+  quantization_annotation {
+    tensor_name: "output"
+    quant_parameter_tensor_names {
+      key: "custom_key"
+      value: "arbitrary_value_output"
+    }
+  }
+}"""
+        self.model = onnx.ModelProto()
+        google.protobuf.text_format.Parse(model_text, self.model)
+
+    def test_deserialize_quantization_annotation(self):
+        model = serde.deserialize_model(self.model)
+        self.assertEqual(
+            model.graph.inputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_input"},
+        )
+        self.assertEqual(
+            model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_intermediate"},
+        )
+        self.assertEqual(
+            model.graph.outputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_output"},
+        )
+
+    def test_serde_roundtrip(self):
+        model = serde.deserialize_model(self.model)
+        serialized_model = serde.serialize_model(model)
+        deserialized_model = serde.deserialize_model(serialized_model)
+        self.assertEqual(
+            deserialized_model.graph.inputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_input"},
+        )
+        self.assertEqual(
+            deserialized_model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_intermediate"},
+        )
+        self.assertEqual(
+            deserialized_model.graph.outputs[0].meta["quant_parameter_tensor_names"],
+            {"custom_key": "arbitrary_value_output"},
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
```