Skip to content

Commit d1a8215

Browse files
[Pass] Remove metadata_props and doc_string from the model (#2182)
Fix #2163 --------- Co-authored-by: Justin Chu <[email protected]>
1 parent ef8e889 commit d1a8215

File tree

4 files changed

+157
-13
lines changed

4 files changed

+157
-13
lines changed

onnxscript/ir/_core.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,16 +2583,14 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
25832583
outputs: The output values of the function.
25842584
opset_imports: Opsets imported by the function.
25852585
doc_string: Documentation string.
2586-
metadata_props: Metadata that will be serialized to the ONNX file.
25872586
meta: Metadata store for graph transform passes.
2587+
metadata_props: Metadata that will be serialized to the ONNX file.
25882588
"""
25892589

25902590
__slots__ = (
25912591
"_attributes",
25922592
"_domain",
25932593
"_graph",
2594-
"_metadata",
2595-
"_metadata_props",
25962594
"_name",
25972595
"_overload",
25982596
)
@@ -2607,15 +2605,12 @@ def __init__(
26072605
# and not from an outer scope
26082606
graph: Graph,
26092607
attributes: Sequence[Attr],
2610-
metadata_props: dict[str, str] | None = None,
26112608
) -> None:
26122609
self._domain = domain
26132610
self._name = name
26142611
self._overload = overload
26152612
self._graph = graph
26162613
self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2617-
self._metadata: _metadata.MetadataStore | None = None
2618-
self._metadata_props: dict[str, str] | None = metadata_props
26192614

26202615
def identifier(self) -> _protocols.OperatorIdentifier:
26212616
return self.domain, self.name, self.overload
@@ -2687,15 +2682,11 @@ def meta(self) -> _metadata.MetadataStore:
26872682
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
26882683
to the ONNX proto.
26892684
"""
2690-
if self._metadata is None:
2691-
self._metadata = _metadata.MetadataStore()
2692-
return self._metadata
2685+
return self._graph.meta
26932686

26942687
@property
26952688
def metadata_props(self) -> dict[str, str]:
2696-
if self._metadata_props is None:
2697-
self._metadata_props = {}
2698-
return self._metadata_props
2689+
return self._graph.metadata_props
26992690

27002691
# Mutation methods
27012692
def append(self, node: Node, /) -> None:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Clear all metadata and docstring from the model, graphs, nodes, and functions."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"ClearMetadataAndDocStringPass",
9+
]
10+
11+
import logging
12+
13+
from onnxscript import ir
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class ClearMetadataAndDocStringPass(ir.passes.InPlacePass):
19+
def call(self, model: ir.Model) -> ir.passes.PassResult:
20+
# 0. TODO: Should we clean model metadata and docstring?
21+
22+
# 1. Clean up the graph and the belonged nodes metadata properties
23+
modified = self._clear_graph_or_function_metadata_and_docstring(model.graph)
24+
25+
# 2. Clean up all of the functions metadata properties
26+
for function in model.functions.values():
27+
modified = (
28+
self._clear_graph_or_function_metadata_and_docstring(function) or modified
29+
)
30+
return ir.passes.PassResult(model, modified=modified)
31+
32+
def _clear_graph_or_function_metadata_and_docstring(
33+
self,
34+
graph_or_function: ir.Graph | ir.Function,
35+
) -> bool:
36+
"""Clear metadata and docstring from the graph or function."""
37+
checked_graphs_or_functions: set[ir.Graph | ir.Function] = set()
38+
modified = False
39+
# Clean up all of the nodes metadata properties
40+
for node in ir.traversal.RecursiveGraphIterator(graph_or_function):
41+
if node.metadata_props:
42+
modified = True
43+
logger.debug("Removed metadata from %s nodes", node.name)
44+
node.metadata_props.clear()
45+
node.doc_string = None
46+
47+
# Clean up the owning graph/function metadata properties
48+
# and doc_string if the graph/function is not already checked
49+
assert node.graph is not None
50+
if node.graph not in checked_graphs_or_functions and (
51+
node.graph.metadata_props or node.graph.doc_string
52+
):
53+
modified = True
54+
logger.debug("Removed metadata from %s graph/function", node.graph.name)
55+
node.graph.metadata_props.clear()
56+
node.graph.doc_string = None
57+
checked_graphs_or_functions.add(node.graph)
58+
return modified
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import numpy as np
8+
9+
from onnxscript import ir
10+
from onnxscript.ir.passes.common import clear_metadata_and_docstring
11+
12+
13+
class TestClearMetadataAndDocStringPass(unittest.TestCase):
14+
def test_pass_with_clear_metadata_and_docstring(self):
15+
# Create a model (node, graph, function) with metadata and docstring
16+
inputs = [
17+
ir.Value(
18+
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
19+
),
20+
ir.Value(
21+
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
22+
),
23+
]
24+
add_node = ir.node(
25+
"Add",
26+
inputs=inputs,
27+
num_outputs=1,
28+
metadata_props={"add_key": "add_value"},
29+
doc_string="This is an Add node",
30+
)
31+
mul_node = ir.node(
32+
"Mul",
33+
inputs=[add_node.outputs[0], inputs[1]],
34+
num_outputs=1,
35+
metadata_props={"mul_key": "mul_value"},
36+
doc_string="This is a Mul node",
37+
)
38+
function = ir.Function(
39+
graph=ir.Graph(
40+
name="my_function",
41+
inputs=inputs,
42+
outputs=mul_node.outputs,
43+
nodes=[add_node, mul_node],
44+
opset_imports={"": 20},
45+
doc_string="This is a function docstring",
46+
metadata_props={"function_key": "function_value"},
47+
),
48+
name="my_function",
49+
domain="my_domain",
50+
attributes=[],
51+
)
52+
# Create a model with the graph and function
53+
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy()))
54+
const_node = ir.node(
55+
"Constant",
56+
inputs=[],
57+
attributes={"value": constant_tensor},
58+
num_outputs=1,
59+
metadata_props={"const_key": "const_value"},
60+
doc_string="This is a Constant node",
61+
)
62+
sub_node = ir.node(
63+
"Sub",
64+
inputs=[function.outputs[0], const_node.outputs[0]],
65+
num_outputs=1,
66+
metadata_props={"sub_key": "sub_value"},
67+
doc_string="This is a Sub node",
68+
)
69+
model = ir.Model(
70+
graph=ir.Graph(
71+
inputs=inputs,
72+
outputs=sub_node.outputs,
73+
nodes=[const_node, sub_node],
74+
opset_imports={"": 20},
75+
doc_string="This is a graph docstring",
76+
metadata_props={"graph_key": "graph_value"},
77+
),
78+
ir_version=10,
79+
functions=[function],
80+
)
81+
# Create a pass to clear metadata and docstring
82+
clear_pass = clear_metadata_and_docstring.ClearMetadataAndDocStringPass()
83+
# Apply the pass
84+
result = clear_pass(model)
85+
# Check that the pass was applied
86+
self.assertTrue(result.modified)
87+
# Check that the metadata and docstring were cleared
88+
self.assertEqual(model.graph.doc_string, None)
89+
self.assertEqual(model.graph.metadata_props, {})
90+
for node in model.graph:
91+
self.assertEqual(node.metadata_props, {})
92+
self.assertEqual(node.doc_string, None)
93+
# Check that the function docstring and metadata were cleared
94+
self.assertEqual(function.doc_string, None)
95+
self.assertEqual(function.metadata_props, {})

onnxscript/ir/serde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
699699
if hasattr(proto, "overload") and proto.overload
700700
else ""
701701
),
702+
metadata_props=deserialize_metadata_props(proto.metadata_props),
702703
)
703704
attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
704705
# Attributes without defaults
@@ -711,7 +712,6 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
711712
overload=getattr(proto, "overload", ""),
712713
graph=graph,
713714
attributes=typing.cast(List[_core.Attr], attributes),
714-
metadata_props=deserialize_metadata_props(proto.metadata_props),
715715
)
716716

717717

0 commit comments

Comments
 (0)