-
Notifications
You must be signed in to change notification settings - Fork 72
[Pass] Support lift constants to initializers pass #2160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c662194
c59af87
c0ea8e6
7db9be0
c007277
1956c7f
0ba7025
c830972
a285ceb
3b15f45
22df674
2c87912
82c4016
a2e9d2a
0a731e6
e648167
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Lift constants to initializers.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"LiftConstantsToInitializersPass", | ||
] | ||
|
||
import logging | ||
|
||
import numpy as np | ||
|
||
from onnxscript import ir | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Lift ``Constant`` nodes into the initializers of the graph that owns them.

    Every ``Constant`` node (in the main graph or any subgraph) that carries
    exactly one supported value attribute is replaced by a graph initializer
    holding the same tensor, and the node is removed.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Convert constant nodes to initializers of the graph each node belongs to.

        Args:
            model: The model to transform in place.

        Returns:
            A ``PassResult`` whose ``modified`` flag is True when at least one
            Constant node was lifted.
        """
        count = 0
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            # BUG FIX: the ONNX default operator-set domain is "ai.onnx"
            # (the original checked "onnx.ai", which never matches), so
            # Constant nodes spelled with the explicit domain were skipped.
            if node.op_type != "Constant" or node.domain not in ("", "ai.onnx"):
                continue

            # A valid Constant node carries exactly one value attribute.
            constant_node_attribute = set(node.attributes.keys())
            if len(constant_node_attribute) != 1:
                logger.debug(
                    "Invalid constant node '%s' has more than one attribute", node.name
                )
                continue

            attr_name, attr_value = next(iter(node.attributes.items()))
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            assert isinstance(attr_value, ir.Attr)
            tensor = _constant_node_attribute_to_tensor(
                attr_name, attr_value, initializer_name
            )
            if tensor is None:
                logger.debug(
                    "Invalid constant node '%s' has unsupported attribute value", node.name
                )
                continue
            # Register an initializer with the tensor value
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
            )
            # RecursiveGraphIterator yields nodes that belong to a graph
            # (possibly a subgraph), so node.graph is expected to be set here.
            assert node.graph is not None
            assert isinstance(node.graph, ir.Graph)
            node.graph.register_initializer(initializer)
            # Replace the constant node's output with the initializer value,
            # then remove the now-unused node (safe=True verifies no users remain).
            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
            node.graph.remove(node, safe=True)
            count += 1
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.debug("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))
|
||
|
||
def _constant_node_attribute_to_tensor(
    attr_name: str, attr_value: ir.Attr, initializer_name: str
) -> ir.Tensor | None:
    """Build a named tensor from a Constant node's value attribute.

    Returns ``None`` when *attr_name* is not one of the supported
    Constant value attributes.
    """
    # Full tensor payload: take it as-is.
    if attr_name == "value":
        return attr_value.as_tensor()  # type: ignore[union-attr,return-value]

    # String payloads (scalar or list) are wrapped in a StringTensor.
    if attr_name in ("value_string", "value_strings"):
        return ir.StringTensor(  # type: ignore[return-value]
            np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
        )

    # Numeric payloads: dispatch to the matching accessor and dtype.
    readers = {
        "value_int": (attr_value.as_int, ir.DataType.INT64),
        "value_ints": (attr_value.as_ints, ir.DataType.INT64),
        "value_float": (attr_value.as_float, ir.DataType.FLOAT),
        "value_floats": (attr_value.as_floats, ir.DataType.FLOAT),
    }
    if attr_name in readers:
        read, dtype = readers[attr_name]
        return ir.tensor(read(), dtype=dtype, name=initializer_name)

    # Unsupported attribute name.
    return None
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
|
||
import unittest | ||
|
||
import numpy as np | ||
import parameterized | ||
|
||
from onnxscript import ir | ||
from onnxscript.ir.passes.common import constant_manipulation | ||
|
||
|
||
class TestLiftConstantsToInitializersPass(unittest.TestCase):
    """Tests for constant_manipulation.LiftConstantsToInitializersPass."""

    @parameterized.parameterized.expand(
        [
            (ir.DataType.FLOAT,),
            (ir.DataType.INT64,),
        ]
    )
    def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
        """A tensor-valued Constant in the main graph becomes an initializer."""
        inputs = [
            ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
            ir.Value(
                name="input_b",
                type=ir.TensorType(ir_dtype),
                shape=ir.Shape((2, 3)),
            ),
        ]

        constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy()))
        const_node = ir.node(
            "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
        )
        add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

        model = ir.Model(
            graph=ir.Graph(
                inputs=inputs,
                outputs=mul_node.outputs,
                nodes=[const_node, add_node, mul_node],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Check that the initializer is not in the graph yet
        self.assertEqual(len(model.graph.initializers), 0)
        # And 1 constant node
        self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

        # Perform lift constants to initializers
        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
        self.assertTrue(result.modified)
        # Check that the constant node is lifted to an initializer
        self.assertEqual(len(result.model.graph.initializers), 1)
        # Check the value
        self.assertEqual(
            result.model.graph.initializers[
                "val_0"
            ].const_value,  # name created by name_authority
            constant_tensor,
        )
        # And 0 constant node
        self.assertEqual(
            len([node for node in result.model.graph if node.op_type == "Constant"]), 0
        )

    def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
        """Constants inside If-node branch subgraphs are lifted into each subgraph."""
        input_value = ir.Value(
            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
        )

        then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        then_const_node = ir.node(
            "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
        )
        # then branch adds the constant to the input
        # else branch multiplies the input by the constant
        add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
        then_graph = ir.Graph(
            inputs=[input_value],
            outputs=[add_node.outputs[0]],
            nodes=[then_const_node, add_node],
            opset_imports={"": 20},
        )
        else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
        else_const_node = ir.node(
            "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
        )
        mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
        else_graph = ir.Graph(
            inputs=[input_value],
            outputs=[mul_node.outputs[0]],
            nodes=[else_const_node, mul_node],
            opset_imports={"": 20},
        )
        # create a conditional node that uses the then and else graphs
        cond_node = ir.node(
            "If",
            inputs=[input_value],
            attributes={"then_branch": then_graph, "else_branch": else_graph},
            num_outputs=1,
        )
        # construct the model
        main_graph = ir.Graph(
            inputs=[input_value],
            outputs=cond_node.outputs,
            nodes=[cond_node],
            opset_imports={"": 20},
        )
        main_graph.sort()
        model = ir.Model(
            graph=main_graph,
            ir_version=10,
        )
        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
        self.assertTrue(result.modified)
        # Check that the constant node is lifted to the subgraph initializers
        for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
            if node.op_type == "Constant":
                raise AssertionError(
                    f"Constant node '{node.name}' was not lifted to initializers"
                )
        # Each branch graph should own exactly one initializer holding its tensor.
        self.assertEqual(len(else_graph.initializers), 1)
        self.assertEqual(len(then_graph.initializers), 1)
        self.assertIs(
            else_graph.initializers["val_0"].const_value,
            else_constant_tensor,
        )
        self.assertIs(
            then_graph.initializers["val_0"].const_value,
            then_constant_tensor,
        )

    @parameterized.parameterized.expand(
        [
            (1.0, "value_float", np.float32),
            (1, "value_int", np.int64),
            ("hello world!", "value_string", np.bytes_),
            ([1.0, 2.0, 3.0], "value_floats", np.float32),
            ([1, 2, 3], "value_ints", np.int64),
            (["hello world!", "thank you."], "value_strings", np.bytes_),
        ]
    )
    def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
        self, value, constant_attribute, np_dtype
    ):
        """Scalar/list value_* attribute variants of Constant are also lifted."""
        input_value = ir.Value(
            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
        )

        constant_value = value
        const_node = ir.node(
            "Constant",
            inputs=[],
            attributes={constant_attribute: constant_value},
            num_outputs=1,
        )
        identity_node_constant = ir.node(
            "Identity", inputs=[const_node.outputs[0]], num_outputs=1
        )
        identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

        model = ir.Model(
            graph=ir.Graph(
                inputs=[input_value],
                outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
                nodes=[identity_node_input, const_node, identity_node_constant],
                opset_imports={"": 20},
            ),
            ir_version=10,
        )

        # Check that the initializer is not in the graph yet
        self.assertEqual(len(model.graph.initializers), 0)
        # And 1 constant node
        self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

        # Perform lift constants to initializers
        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
        self.assertTrue(result.modified)
        # Check that the constant node is lifted to an initializer
        self.assertEqual(len(result.model.graph.initializers), 1)
        # "val_1" is the auto-generated name of the Constant's output
        # (it is the second value created by the name authority).
        np.testing.assert_array_equal(
            result.model.graph.initializers["val_1"].const_value.numpy(),
            np.array(constant_value, dtype=np_dtype),
        )
Uh oh!
There was an error while loading. Please reload this page.