Skip to content

[Pass] Support lift constants to initializers pass #2160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,8 @@ def __init__(
# Be sure the initialize the name authority before extending the nodes
# because it is used to name the nodes and their outputs
self._name_authority = _name_authority.NameAuthority()
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
self._set_input_and_initializer_value_names_into_name_authority()
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
self.extend(nodes)

Expand Down Expand Up @@ -1999,6 +2001,12 @@ def __iter__(self) -> Iterator[Node]:
def __reversed__(self) -> Iterator[Node]:
return reversed(self._nodes)

def _set_input_and_initializer_value_names_into_name_authority(self):
for value in self.inputs:
self._name_authority.register_or_name_value(value)
for value in self.initializers.values():
self._name_authority.register_or_name_value(value)

def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
if node.graph is not None and node.graph is not self:
Expand Down
101 changes: 101 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Lift constants to initializers."""

from __future__ import annotations

__all__ = [
"LiftConstantsToInitializersPass",
]

import logging

import numpy as np

from onnxscript import ir

logger = logging.getLogger(__name__)


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes in main graph to initializers."""
count = 0
for node in model.graph:
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue

constant_node_attribute = set(node.attributes.keys())
if len(constant_node_attribute) != 1:
logger.debug(
"Invalid constant node '%s' has more than one attribute", node.name
)
continue

attr_name, attr_value = next(iter(node.attributes.items()))
initializer_name = node.outputs[0].name
assert initializer_name is not None
assert isinstance(attr_value, ir.Attr)
tensor = _constant_node_attribute_to_tensor(
attr_name, attr_value, initializer_name
)
if tensor is None:
logger.debug(
"Invalid constant node '%s' has unsupported attribute value", node.name
)
continue
# Register an initializer with the tensor value
initializer = ir.Value(
name=initializer_name,
shape=tensor.shape, # type: ignore[arg-type]
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
# TODO(titaiwang): Is it possible that the initializer name has
# been taken?
model.graph.register_initializer(initializer)
# Replace the constant node with the initilizer
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
model.graph.remove(node, safe=True)
count += 1
logger.info(
"Converted constant node '%s' to initializer '%s'", node.name, initializer_name
)
if count:
logger.info("Lifted %s constants to initializers", count)
return ir.passes.PassResult(model, modified=bool(count))


def _constant_node_attribute_to_tensor(
attr_name: str, attr_value: ir.Attr, initializer_name: str
) -> ir.Tensor | None:
"""Convert constant node attribute to tensor."""
if attr_name == "value":
tensor = attr_value.as_tensor() # type: ignore[union-attr]
elif attr_name == "value_int":
tensor = ir.Tensor(
np.array(attr_value.as_int(), dtype=np.int64), name=initializer_name
)
elif attr_name == "value_ints":
tensor = ir.Tensor(
np.array(attr_value.as_ints(), dtype=np.int64), name=initializer_name
)
elif attr_name == "value_float":
tensor = ir.Tensor(
np.array(attr_value.as_float(), dtype=np.float32), name=initializer_name
)
elif attr_name == "value_floats":
tensor = ir.Tensor(
np.array(attr_value.as_floats(), dtype=np.float32), name=initializer_name
)
elif attr_name == "value_string":
tensor = ir.Tensor(
np.array(attr_value.as_string(), dtype=np.object_), name=initializer_name
)
elif attr_name == "value_strings":
tensor = ir.Tensor(
np.array(attr_value.as_strings(), dtype=np.object_), name=initializer_name
)
else:
tensor = None
return tensor # type: ignore[return-value]
67 changes: 67 additions & 0 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import parameterized

from onnxscript import ir
from onnxscript.ir.passes.common import constant_manipulation


class TestLiftConstantsToInitializersPass(unittest.TestCase):
@parameterized.parameterized.expand(
[
(ir.DataType.FLOAT, np.float32),
(ir.DataType.INT64, np.int64),
]
)
def test_pass_with_lifting_constants_to_initializers(self, ir_dtype, numpy_dtype):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
ir.Value(
name="input_b",
type=ir.TensorType(ir_dtype),
shape=ir.Shape((2, 3)),
),
]

constant_tensor = ir.tensor(np.random.rand(2, 3).astype(numpy_dtype))
attribute = ir.convenience.convert_attributes({"value": constant_tensor})
const_node = ir.Node("", "Constant", inputs=[], attributes=attribute, num_outputs=1)
add_node = ir.Node("", "Add", inputs=[inputs[0], const_node.outputs[0]])
mul_node = ir.Node("", "Mul", inputs=[add_node.outputs[0], inputs[1]])

model = ir.Model(
graph=ir.Graph(
inputs=inputs,
outputs=mul_node.outputs,
nodes=[const_node, add_node, mul_node],
opset_imports={"": 20},
),
ir_version=10,
)

# Check that the initializer is not in the graph yet
self.assertEqual(len(model.graph.initializers), 0)
# And 1 constant node
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
# Check the value
self.assertEqual(
result.model.graph.initializers[
"val_0"
].const_value, # name created by name_authority
constant_tensor,
)
# And 0 constant node
self.assertEqual(
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)
2 changes: 2 additions & 0 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize_ir(
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
Expand Down
Loading