Skip to content

Commit e659cb4

Browse files
[Pass] Support lift constants to initializers pass (#2160)
Fix #2156 --------- Co-authored-by: Justin Chu <[email protected]>
1 parent 3a6e4cc commit e659cb4

File tree

4 files changed

+294
-0
lines changed

4 files changed

+294
-0
lines changed

onnxscript/ir/_core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,6 +1924,8 @@ def __init__(
19241924
# Be sure the initialize the name authority before extending the nodes
19251925
# because it is used to name the nodes and their outputs
19261926
self._name_authority = _name_authority.NameAuthority()
1927+
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
1928+
self._set_input_and_initializer_value_names_into_name_authority()
19271929
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
19281930
self.extend(nodes)
19291931

@@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
19992001
def __reversed__(self) -> Iterator[Node]:
20002002
return reversed(self._nodes)
20012003

2004+
def _set_input_and_initializer_value_names_into_name_authority(self):
2005+
for value in self.inputs:
2006+
self._name_authority.register_or_name_value(value)
2007+
for value in self.initializers.values():
2008+
self._name_authority.register_or_name_value(value)
2009+
20022010
def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
20032011
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
20042012
if node.graph is not None and node.graph is not self:
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Lift constants to initializers."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"LiftConstantsToInitializersPass",
9+
]
10+
11+
import logging
12+
13+
import numpy as np
14+
15+
from onnxscript import ir
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
21+
def call(self, model: ir.Model) -> ir.passes.PassResult:
22+
"""Convert constant nodes from node belonged graph to its initializers."""
23+
count = 0
24+
for node in ir.traversal.RecursiveGraphIterator(model.graph):
25+
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
26+
continue
27+
28+
constant_node_attribute = set(node.attributes.keys())
29+
if len(constant_node_attribute) != 1:
30+
logger.debug(
31+
"Invalid constant node '%s' has more than one attribute", node.name
32+
)
33+
continue
34+
35+
attr_name, attr_value = next(iter(node.attributes.items()))
36+
initializer_name = node.outputs[0].name
37+
assert initializer_name is not None
38+
assert isinstance(attr_value, ir.Attr)
39+
tensor = _constant_node_attribute_to_tensor(
40+
attr_name, attr_value, initializer_name
41+
)
42+
if tensor is None:
43+
logger.debug(
44+
"Invalid constant node '%s' has unsupported attribute value", node.name
45+
)
46+
continue
47+
# Register an initializer with the tensor value
48+
initializer = ir.Value(
49+
name=initializer_name,
50+
shape=tensor.shape, # type: ignore[arg-type]
51+
type=ir.TensorType(tensor.dtype),
52+
const_value=tensor,
53+
)
54+
assert node.graph is not None
55+
assert isinstance(node.graph, ir.Graph)
56+
node.graph.register_initializer(initializer)
57+
# Replace the constant node with the initilizer
58+
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
59+
node.graph.remove(node, safe=True)
60+
count += 1
61+
logger.debug(
62+
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
63+
)
64+
if count:
65+
logger.debug("Lifted %s constants to initializers", count)
66+
return ir.passes.PassResult(model, modified=bool(count))
67+
68+
69+
def _constant_node_attribute_to_tensor(
70+
attr_name: str, attr_value: ir.Attr, initializer_name: str
71+
) -> ir.Tensor | None:
72+
"""Convert constant node attribute to tensor."""
73+
if attr_name == "value":
74+
tensor = attr_value.as_tensor() # type: ignore[union-attr]
75+
elif attr_name == "value_int":
76+
tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
77+
elif attr_name == "value_ints":
78+
tensor = ir.tensor(
79+
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
80+
)
81+
elif attr_name == "value_float":
82+
tensor = ir.tensor(
83+
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
84+
)
85+
elif attr_name == "value_floats":
86+
tensor = ir.tensor(
87+
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
88+
)
89+
elif attr_name in ("value_string", "value_strings"):
90+
tensor = ir.StringTensor(
91+
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
92+
)
93+
else:
94+
tensor = None
95+
return tensor # type: ignore[return-value]
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import unittest
6+
7+
import numpy as np
8+
import parameterized
9+
10+
from onnxscript import ir
11+
from onnxscript.ir.passes.common import constant_manipulation
12+
13+
14+
class TestLiftConstantsToInitializersPass(unittest.TestCase):
15+
@parameterized.parameterized.expand(
16+
[
17+
(ir.DataType.FLOAT,),
18+
(ir.DataType.INT64,),
19+
]
20+
)
21+
def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
22+
inputs = [
23+
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
24+
ir.Value(
25+
name="input_b",
26+
type=ir.TensorType(ir_dtype),
27+
shape=ir.Shape((2, 3)),
28+
),
29+
]
30+
31+
constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy()))
32+
const_node = ir.node(
33+
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
34+
)
35+
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
36+
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
37+
38+
model = ir.Model(
39+
graph=ir.Graph(
40+
inputs=inputs,
41+
outputs=mul_node.outputs,
42+
nodes=[const_node, add_node, mul_node],
43+
opset_imports={"": 20},
44+
),
45+
ir_version=10,
46+
)
47+
48+
# Check that the initializer is not in the graph yet
49+
self.assertEqual(len(model.graph.initializers), 0)
50+
# And 1 constant node
51+
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
52+
53+
# Perform lift constants to initializers
54+
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
55+
self.assertTrue(result.modified)
56+
# Check that the constant node is lifted to an initializer
57+
self.assertEqual(len(result.model.graph.initializers), 1)
58+
# Check the value
59+
self.assertEqual(
60+
result.model.graph.initializers[
61+
"val_0"
62+
].const_value, # name created by name_authority
63+
constant_tensor,
64+
)
65+
# And 0 constant node
66+
self.assertEqual(
67+
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
68+
)
69+
70+
def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
71+
input_value = ir.Value(
72+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
73+
)
74+
75+
then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
76+
then_const_node = ir.node(
77+
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
78+
)
79+
# then branch adds the constant to the input
80+
# else branch multiplies the input by the constant
81+
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
82+
then_graph = ir.Graph(
83+
inputs=[input_value],
84+
outputs=[add_node.outputs[0]],
85+
nodes=[then_const_node, add_node],
86+
opset_imports={"": 20},
87+
)
88+
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
89+
else_const_node = ir.node(
90+
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
91+
)
92+
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
93+
else_graph = ir.Graph(
94+
inputs=[input_value],
95+
outputs=[mul_node.outputs[0]],
96+
nodes=[else_const_node, mul_node],
97+
opset_imports={"": 20},
98+
)
99+
# create a conditional node that uses the then and else graphs
100+
cond_node = ir.node(
101+
"If",
102+
inputs=[input_value],
103+
attributes={"then_branch": then_graph, "else_branch": else_graph},
104+
num_outputs=1,
105+
)
106+
# construnct the model
107+
main_graph = ir.Graph(
108+
inputs=[input_value],
109+
outputs=cond_node.outputs,
110+
nodes=[cond_node],
111+
opset_imports={"": 20},
112+
)
113+
main_graph.sort()
114+
model = ir.Model(
115+
graph=main_graph,
116+
ir_version=10,
117+
)
118+
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
119+
self.assertTrue(result.modified)
120+
# Check that the constant node is lifted to the subgraph initializers
121+
for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
122+
if node.op_type == "Constant":
123+
raise AssertionError(
124+
f"Constant node '{node.name}' was not lifted to initializers"
125+
)
126+
self.assertEqual(len(else_graph.initializers), 1)
127+
self.assertEqual(len(then_graph.initializers), 1)
128+
self.assertIs(
129+
else_graph.initializers["val_0"].const_value,
130+
else_constant_tensor,
131+
)
132+
self.assertIs(
133+
then_graph.initializers["val_0"].const_value,
134+
then_constant_tensor,
135+
)
136+
137+
@parameterized.parameterized.expand(
138+
[
139+
(1.0, "value_float", np.float32),
140+
(1, "value_int", np.int64),
141+
("hello world!", "value_string", np.bytes_),
142+
([1.0, 2.0, 3.0], "value_floats", np.float32),
143+
([1, 2, 3], "value_ints", np.int64),
144+
(["hello world!", "thank you."], "value_strings", np.bytes_),
145+
]
146+
)
147+
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
148+
self, value, constant_attribute, np_dtype
149+
):
150+
input_value = ir.Value(
151+
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
152+
)
153+
154+
constant_value = value
155+
const_node = ir.node(
156+
"Constant",
157+
inputs=[],
158+
attributes={constant_attribute: constant_value},
159+
num_outputs=1,
160+
)
161+
identity_node_constant = ir.node(
162+
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
163+
)
164+
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)
165+
166+
model = ir.Model(
167+
graph=ir.Graph(
168+
inputs=[input_value],
169+
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
170+
nodes=[identity_node_input, const_node, identity_node_constant],
171+
opset_imports={"": 20},
172+
),
173+
ir_version=10,
174+
)
175+
176+
# Check that the initializer is not in the graph yet
177+
self.assertEqual(len(model.graph.initializers), 0)
178+
# And 1 constant node
179+
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)
180+
181+
# Perform lift constants to initializers
182+
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
183+
self.assertTrue(result.modified)
184+
# Check that the constant node is lifted to an initializer
185+
self.assertEqual(len(result.model.graph.initializers), 1)
186+
np.testing.assert_array_equal(
187+
result.model.graph.initializers["val_1"].const_value.numpy(),
188+
np.array(constant_value, dtype=np_dtype),
189+
)

onnxscript/optimizer/_optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import logging
66

7+
import onnxscript.ir.passes.common.constant_manipulation
78
import onnxscript.ir.passes.common.unused_removal
89
from onnxscript import ir, rewriter
910
from onnxscript.optimizer import _constant_folding, _inliner
@@ -52,6 +53,7 @@ def optimize_ir(
5253
early_stop=stop_if_no_change,
5354
),
5455
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
56+
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
5557
)
5658
assert optimizer_pass.in_place
5759
result = optimizer_pass(model)

0 commit comments

Comments
 (0)