
[Pass] Support lift constants to initializers pass #2160


Merged
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
@@ -1924,6 +1924,8 @@ def __init__(
# Be sure to initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

@@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
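For context, here is a minimal sketch of the collision this early registration prevents. It is an illustration, not code from the PR; the generated names (val_0, val_1, ...) follow the pattern the tests below rely on, and the exact fallback name is an assumption.

import numpy as np
from onnxscript import ir

# An initializer that already carries a generated-looking name.
init = ir.Value(
    name="val_0",
    type=ir.TensorType(ir.DataType.FLOAT),
    shape=ir.Shape((2,)),
    const_value=ir.tensor(np.zeros(2, dtype=np.float32)),
)
# A node whose output is unnamed; the graph must assign it a name.
node = ir.node("Identity", inputs=[init], num_outputs=1)

graph = ir.Graph(
    inputs=[],
    outputs=node.outputs,
    nodes=[node],
    initializers=[init],
    opset_imports={"": 20},
)
# Because "val_0" is registered with the name authority before the nodes
# are named, the Identity output should receive a fresh name instead of
# colliding with the initializer.
print(node.outputs[0].name)  # expected: something other than "val_0"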
95 changes: 95 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
@@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

import numpy as np

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes from node belonged graph to its initializers."""
count = 0
for node in ir.traversal.RecursiveGraphIterator(model.graph):
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue

constant_node_attribute = set(node.attributes.keys())
if len(constant_node_attribute) != 1:
logger.debug(
"Invalid constant node '%s' has more than one attribute", node.name
)
continue


attr_name, attr_value = next(iter(node.attributes.items()))
initializer_name = node.outputs[0].name
assert initializer_name is not None
assert isinstance(attr_value, ir.Attr)
tensor = _constant_node_attribute_to_tensor(
attr_name, attr_value, initializer_name
)
if tensor is None:
logger.debug(
"Invalid constant node '%s' has unsupported attribute value", node.name
)
continue

# Register an initializer with the tensor value
initializer = ir.Value(
name=initializer_name,
shape=tensor.shape, # type: ignore[arg-type]
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
assert node.graph is not None
assert isinstance(node.graph, ir.Graph)
Review thread on this hunk:

Contributor Author: @justinchuby does this make sense? I think there should not be any ir.Function node coming out of the recursive iterator?

justinchuby (Collaborator, Apr 10, 2025): Does make sense. Thanks! The reason it was annotated with graph | function is that the "owning graph" can be a function when the node is part of a function. Maybe there are better ways to do it 🤔

gramalingam (Collaborator, Apr 10, 2025): (Unrelated to this PR) Isn't a Function object a wrapper around a Graph object? Doesn't node.graph return that graph object even in the case of function nodes?

Collaborator: Technically the graph in a function is private and not used directly. It is currently an implementation detail.

Collaborator: Sorry, I was wrong. It points to a graph when we call function.append, but not when we call ir.Node(graph=function). I need to figure out how to reconcile this. Suggestions appreciated. #2181

node.graph.register_initializer(initializer)
# Replace the constant node with the initializer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
node.graph.remove(node, safe=True)
count += 1
logger.debug(
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
)
if count:
logger.debug("Lifted %s constants to initializers", count)
return ir.passes.PassResult(model, modified=bool(count))


def _constant_node_attribute_to_tensor(
attr_name: str, attr_value: ir.Attr, initializer_name: str
) -> ir.Tensor | None:
"""Convert constant node attribute to tensor."""
if attr_name == "value":
tensor = attr_value.as_tensor() # type: ignore[union-attr]
elif attr_name == "value_int":
tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
elif attr_name == "value_ints":
tensor = ir.tensor(
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
)
elif attr_name == "value_float":
tensor = ir.tensor(
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
)
elif attr_name == "value_floats":
tensor = ir.tensor(
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
)
elif attr_name in ("value_string", "value_strings"):
tensor = ir.StringTensor(
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
)
else:
tensor = None

return tensor # type: ignore[return-value]
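For readers skimming the diff, a minimal standalone usage sketch of the new pass, mirroring the tests below (the generated initializer name "val_0" comes from the graph's name authority):

import numpy as np
from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation

# A tiny model: Add(x, Constant).
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2,)))
const_node = ir.node(
    "Constant",
    inputs=[],
    attributes={"value": ir.tensor(np.ones(2, dtype=np.float32))},
    num_outputs=1,
)
add_node = ir.node("Add", inputs=[x, const_node.outputs[0]])
model = ir.Model(
    graph=ir.Graph(
        inputs=[x],
        outputs=add_node.outputs,
        nodes=[const_node, add_node],
        opset_imports={"": 20},
    ),
    ir_version=10,
)

# The pass removes the Constant node and registers its tensor as an initializer.
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
assert result.modified
assert len(result.model.graph.initializers) == 1
assert not any(node.op_type == "Constant" for node in result.model.graph)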
189 changes: 189 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
@@ -0,0 +1,189 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import parameterized

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation


class TestLiftConstantsToInitializersPass(unittest.TestCase):
@parameterized.parameterized.expand(
[
(ir.DataType.FLOAT,),
(ir.DataType.INT64,),
]
)
def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
ir.Value(
name="input_b",
type=ir.TensorType(ir_dtype),
shape=ir.Shape((2, 3)),
),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy()))
const_node = ir.node(
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
)
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
# Check the value
self.assertEqual(
result.model.graph.initializers[
"val_0"
].const_value, # name created by name_authority
constant_tensor,
)
# And 0 constant node
self.assertEqual(
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)

def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
then_const_node = ir.node(
"Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
)
# then branch adds the constant to the input
# else branch multiplies the input by the constant
add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
then_graph = ir.Graph(
inputs=[input_value],
outputs=[add_node.outputs[0]],
nodes=[then_const_node, add_node],
opset_imports={"": 20},
)
else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
else_const_node = ir.node(
"Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
)
mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
else_graph = ir.Graph(
inputs=[input_value],
outputs=[mul_node.outputs[0]],
nodes=[else_const_node, mul_node],
opset_imports={"": 20},
)
# create a conditional node that uses the then and else graphs
cond_node = ir.node(
"If",
inputs=[input_value],
attributes={"then_branch": then_graph, "else_branch": else_graph},
num_outputs=1,
)
# construct the model
main_graph = ir.Graph(
inputs=[input_value],
outputs=cond_node.outputs,
nodes=[cond_node],
opset_imports={"": 20},
)
main_graph.sort()
model = ir.Model(
graph=main_graph,
ir_version=10,
)
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to the subgraph initializers
for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
if node.op_type == "Constant":
raise AssertionError(
f"Constant node '{node.name}' was not lifted to initializers"
)
self.assertEqual(len(else_graph.initializers), 1)
self.assertEqual(len(then_graph.initializers), 1)
self.assertIs(
else_graph.initializers["val_0"].const_value,
else_constant_tensor,
)
self.assertIs(
then_graph.initializers["val_0"].const_value,
then_constant_tensor,
)

@parameterized.parameterized.expand(
[
(1.0, "value_float", np.float32),
(1, "value_int", np.int64),
("hello world!", "value_string", np.bytes_),
([1.0, 2.0, 3.0], "value_floats", np.float32),
([1, 2, 3], "value_ints", np.int64),
(["hello world!", "thank you."], "value_strings", np.bytes_),
]
)
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
self, value, constant_attribute, np_dtype
):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)

constant_value = value
const_node = ir.node(
"Constant",
inputs=[],
attributes={constant_attribute: constant_value},
num_outputs=1,
)
identity_node_constant = ir.node(
"Identity", inputs=[const_node.outputs[0]], num_outputs=1
)
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]],
nodes=[identity_node_input, const_node, identity_node_constant],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
np.testing.assert_array_equal(
result.model.graph.initializers["val_1"].const_value.numpy(),
np.array(constant_value, dtype=np_dtype),
)
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
@@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
@@ -52,6 +53,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
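For orientation, a hedged sketch of how the new pass chains with the others. The outer combinator shown here, ir.passes.Sequential, is an assumption: the diff elides the constructor of optimizer_pass, and only the appended pass list is visible above.

from onnxscript import ir
import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal

# Assumption: Sequential runs the passes in order and reports whether any
# of them modified the model; the real pipeline also folds constants and
# applies rewrite rules before these two passes.
optimizer_pass = ir.passes.Sequential(
    onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
    onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
# result = optimizer_pass(model)  # returns an ir.passes.PassResult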