[Pass] Fix bugs in LiftConstantsToInitializersPass #2189

Merged
89 changes: 54 additions & 35 deletions onnxscript/ir/passes/common/constant_manipulation.py
@@ -18,13 +18,28 @@


class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
"""Lift constants to initializers.

Attributes:
        lift_all_constants: Whether to lift all Constant nodes, including those that do not
            contain a tensor attribute (e.g. those with ``value_ints``). Defaults to False,
            where only Constants with the ``value`` attribute are lifted.
"""

def __init__(self, lift_all_constants: bool = False):
super().__init__()
self._lift_all_constants = lift_all_constants

def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Convert constant nodes from node belonged graph to its initializers."""
count = 0
for node in ir.traversal.RecursiveGraphIterator(model.graph):
assert node.graph is not None
if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
continue

if node.outputs[0].is_graph_output():
logger.debug(
"Constant node '%s' is used as output, so it can't be lifted.", node.name
)
continue
constant_node_attribute = set(node.attributes.keys())
if len(constant_node_attribute) != 1:
logger.debug(
@@ -36,13 +51,11 @@
initializer_name = node.outputs[0].name
assert initializer_name is not None
assert isinstance(attr_value, ir.Attr)
-            tensor = _constant_node_attribute_to_tensor(
-                attr_name, attr_value, initializer_name
+            tensor = self._constant_node_attribute_to_tensor(
+                node, attr_name, attr_value, initializer_name
             )
if tensor is None:
-                logger.debug(
-                    "Invalid constant node '%s' has unsupported attribute value", node.name
-                )
+                # The reason for None is logged in _constant_node_attribute_to_tensor
continue
# Register an initializer with the tensor value
initializer = ir.Value(
@@ -51,7 +64,6 @@
type=ir.TensorType(tensor.dtype),
const_value=tensor,
)
-            assert node.graph is not None
assert isinstance(node.graph, ir.Graph)
node.graph.register_initializer(initializer)
            # Replace the constant node with the initializer
@@ -65,31 +77,38 @@
logger.debug("Lifted %s constants to initializers", count)
return ir.passes.PassResult(model, modified=bool(count))

-def _constant_node_attribute_to_tensor(
-    attr_name: str, attr_value: ir.Attr, initializer_name: str
-) -> ir.Tensor | None:
-    """Convert constant node attribute to tensor."""
-    if attr_name == "value":
-        tensor = attr_value.as_tensor()  # type: ignore[union-attr]
-    elif attr_name == "value_int":
-        tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name)
-    elif attr_name == "value_ints":
-        tensor = ir.tensor(
-            attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
-        )
-    elif attr_name == "value_float":
-        tensor = ir.tensor(
-            attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
-        )
-    elif attr_name == "value_floats":
-        tensor = ir.tensor(
-            attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
-        )
-    elif attr_name in ("value_string", "value_strings"):
-        tensor = ir.StringTensor(
-            np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
-        )
-    else:
-        tensor = None
-    return tensor  # type: ignore[return-value]
+    def _constant_node_attribute_to_tensor(
+        self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
+    ) -> ir.Tensor | None:
+        """Convert constant node attribute to tensor."""
+        if not self._lift_all_constants and attr_name != "value":
+            logger.debug(
+                "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
+            )
+            return None
+        if attr_name == "value":
+            tensor = attr_value.as_tensor()  # type: ignore[union-attr]
+        elif attr_name == "value_int":
+            tensor = ir.tensor(
+                attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
+            )
+        elif attr_name == "value_ints":
+            tensor = ir.tensor(
+                attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
+            )
+        elif attr_name == "value_float":
+            tensor = ir.tensor(
+                attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
+            )
+        elif attr_name == "value_floats":
+            tensor = ir.tensor(
+                attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
+            )
+        elif attr_name in ("value_string", "value_strings"):
+            tensor = ir.StringTensor(
+                np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
+            )
+        else:
+            tensor = None
+        return tensor  # type: ignore[return-value]
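
For orientation, a minimal usage sketch of the updated pass. This is a sketch, not code from the PR: it assumes the `from onnxscript import ir` import path and reuses the construction helpers (ir.node, ir.tensor, ir.Graph, ir.Model) exactly as they appear in the tests below; the tiny model itself is hypothetical.

import numpy as np

from onnxscript import ir  # assumed import path
from onnxscript.ir.passes.common import constant_manipulation

# A tiny model whose only Constant node carries a tensor `value` attribute.
const_node = ir.node(
    "Constant",
    inputs=[],
    attributes={"value": ir.tensor(np.ones((2, 3), dtype=np.float32))},
    num_outputs=1,
)
identity_node = ir.node("Identity", inputs=[const_node.outputs[0]], num_outputs=1)
model = ir.Model(
    graph=ir.Graph(
        inputs=[],
        outputs=[identity_node.outputs[0]],
        nodes=[const_node, identity_node],
        opset_imports={"": 20},
    ),
    ir_version=10,
)

# The default (lift_all_constants=False) still lifts tensor-valued Constants.
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
assert result.modified
assert len(result.model.graph.initializers) == 1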
102 changes: 82 additions & 20 deletions onnxscript/ir/passes/common/constant_manipulation_test.py
@@ -14,11 +14,15 @@
class TestLiftConstantsToInitializersPass(unittest.TestCase):
@parameterized.parameterized.expand(
[
-            (ir.DataType.FLOAT,),
-            (ir.DataType.INT64,),
+            (ir.DataType.FLOAT, True),
+            (ir.DataType.FLOAT, False),
+            (ir.DataType.INT64, True),
+            (ir.DataType.INT64, False),
]
)
-    def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
+    def test_pass_with_lifting_float_and_int_constants_to_initializers(
+        self, ir_dtype: ir.DataType, lift_all_constants: bool
+    ):
inputs = [
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
ir.Value(
@@ -51,7 +55,9 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype):
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
-        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
+        result = constant_manipulation.LiftConstantsToInitializersPass(
+            lift_all_constants=lift_all_constants
+        )(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 1)
Expand All @@ -67,7 +73,15 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
)

-    def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
+    @parameterized.parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_pass_with_lifting_constants_to_initializers_within_subgraph(
+        self, lift_all_constants: bool
+    ):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)
@@ -115,7 +129,9 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self):
graph=main_graph,
ir_version=10,
)
-        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
+        result = constant_manipulation.LiftConstantsToInitializersPass(
+            lift_all_constants=lift_all_constants
+        )(model)
self.assertTrue(result.modified)
# Check that the constant node is lifted to the subgraph initializers
for node in ir.traversal.RecursiveGraphIterator(result.model.graph):
@@ -136,16 +152,26 @@

@parameterized.parameterized.expand(
[
(1.0, "value_float", np.float32),
(1, "value_int", np.int64),
("hello world!", "value_string", np.bytes_),
([1.0, 2.0, 3.0], "value_floats", np.float32),
([1, 2, 3], "value_ints", np.int64),
(["hello world!", "thank you."], "value_strings", np.bytes_),
(1.0, "value_float", np.float32, True),
(1.0, "value_float", np.float32, False),
(1, "value_int", np.int64, True),
(1, "value_int", np.int64, False),
("hello world!", "value_string", np.bytes_, True),
("hello world!", "value_string", np.bytes_, False),
([1.0, 2.0, 3.0], "value_floats", np.float32, True),
([1.0, 2.0, 3.0], "value_floats", np.float32, False),
([1, 2, 3], "value_ints", np.int64, True),
([1, 2, 3], "value_ints", np.int64, False),
(["hello world!", "thank you."], "value_strings", np.bytes_, True),
(["hello world!", "thank you."], "value_strings", np.bytes_, False),
]
)
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
-        self, value, constant_attribute, np_dtype
+        self,
+        value: float | int | str | list[float] | list[int] | list[str],
+        constant_attribute: str,
+        np_dtype: type[np.dtype],
+        lift_all_constants: bool,
):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
@@ -179,11 +205,47 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1)

# Perform lift constants to initializers
-        result = constant_manipulation.LiftConstantsToInitializersPass()(model)
-        self.assertTrue(result.modified)
-        # Check that the constant node is lifted to an initializer
-        self.assertEqual(len(result.model.graph.initializers), 1)
-        np.testing.assert_array_equal(
-            result.model.graph.initializers["val_1"].const_value.numpy(),
-            np.array(constant_value, dtype=np_dtype),
-        )
+        result = constant_manipulation.LiftConstantsToInitializersPass(
+            lift_all_constants=lift_all_constants
+        )(model)
+        if lift_all_constants:
+            self.assertTrue(result.modified)
+            # Check that the constant node is lifted to an initializer
+            self.assertEqual(len(result.model.graph.initializers), 1)
+            np.testing.assert_array_equal(
+                result.model.graph.initializers["val_1"].const_value.numpy(),
+                np.array(constant_value, dtype=np_dtype),
+            )
+        else:
+            self.assertFalse(result.modified)
+            # Check that the constant node is not lifted to an initializer
+            self.assertEqual(len(result.model.graph.initializers), 0)

def test_not_lifting_constants_to_initializers_when_it_is_output(self):
input_value = ir.Value(
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
)
identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1)

constant_value = ir.tensor(np.random.rand(2, 3).astype(np.float32))
const_node = ir.node(
"Constant",
inputs=[],
attributes={"value": constant_value},
num_outputs=1,
)

model = ir.Model(
graph=ir.Graph(
inputs=[input_value],
outputs=[identity_node_input.outputs[0], const_node.outputs[0]],
nodes=[identity_node_input, const_node],
opset_imports={"": 20},
),
ir_version=10,
)

result = constant_manipulation.LiftConstantsToInitializersPass()(model)
self.assertFalse(result.modified)
# Check that the constant node is not lifted to an initializer
self.assertEqual(len(result.model.graph.initializers), 0)
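
To make the new flag concrete, a sketch contrasting the two modes on a Constant that has no tensor `value`. One assumption beyond the tests above: that ir.node accepts a plain Python list for the `value_ints` attribute (the parameterized test suggests this, but the attribute construction is collapsed in the diff).

# Hypothetical: a Constant carrying only `value_ints`.
const_node = ir.node(
    "Constant",
    inputs=[],
    attributes={"value_ints": [1, 2, 3]},  # assumed to coerce to an ints attribute
    num_outputs=1,
)
identity_node = ir.node("Identity", inputs=[const_node.outputs[0]], num_outputs=1)
model = ir.Model(
    graph=ir.Graph(
        inputs=[],
        outputs=[identity_node.outputs[0]],
        nodes=[const_node, identity_node],
        opset_imports={"": 20},
    ),
    ir_version=10,
)

# Default mode skips non-tensor Constants: nothing is lifted.
result = constant_manipulation.LiftConstantsToInitializersPass()(model)
assert not result.modified
assert len(result.model.graph.initializers) == 0

# Opt-in mode lifts the ints as an INT64 initializer.
result = constant_manipulation.LiftConstantsToInitializersPass(lift_all_constants=True)(model)
assert result.modified
assert len(result.model.graph.initializers) == 1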