diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py index 3032b33d44..3245415c31 100644 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ b/onnxscript/ir/passes/common/constant_manipulation.py @@ -18,13 +18,28 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass): + """Lift constants to initializers. + + Attributes: + lift_all_constants: Whether to lift all Constant nodes, including those that do not contain a tensor attribute (e.g. with value_ints etc.) + Defaults to False, where only Constants with the ``value`` attribute are lifted. + """ + + def __init__(self, lift_all_constants: bool = False): + super().__init__() + self._lift_all_constants = lift_all_constants + def call(self, model: ir.Model) -> ir.passes.PassResult: - """Convert constant nodes from node belonged graph to its initializers.""" count = 0 for node in ir.traversal.RecursiveGraphIterator(model.graph): + assert node.graph is not None if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue - + if node.outputs[0].is_graph_output(): + logger.debug( + "Constant node '%s' is used as output, so it can't be lifted.", node.name + ) + continue constant_node_attribute = set(node.attributes.keys()) if len(constant_node_attribute) != 1: logger.debug( @@ -36,13 +51,11 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: initializer_name = node.outputs[0].name assert initializer_name is not None assert isinstance(attr_value, ir.Attr) - tensor = _constant_node_attribute_to_tensor( - attr_name, attr_value, initializer_name + tensor = self._constant_node_attribute_to_tensor( + node, attr_name, attr_value, initializer_name ) if tensor is None: - logger.debug( - "Invalid constant node '%s' has unsupported attribute value", node.name - ) + # The reason for None is logged in _constant_node_attribute_to_tensor continue # Register an initializer with the tensor value initializer = ir.Value( @@ -51,7 
+64,6 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: type=ir.TensorType(tensor.dtype), const_value=tensor, ) - assert node.graph is not None assert isinstance(node.graph, ir.Graph) node.graph.register_initializer(initializer) # Replace the constant node with the initilizer @@ -65,31 +77,38 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: logger.debug("Lifted %s constants to initializers", count) return ir.passes.PassResult(model, modified=bool(count)) + def _constant_node_attribute_to_tensor( + self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str + ) -> ir.Tensor | None: + """Convert constant node attribute to tensor.""" + if not self._lift_all_constants and attr_name != "value": + logger.debug( + "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name + ) + return None -def _constant_node_attribute_to_tensor( - attr_name: str, attr_value: ir.Attr, initializer_name: str -) -> ir.Tensor | None: - """Convert constant node attribute to tensor.""" - if attr_name == "value": - tensor = attr_value.as_tensor() # type: ignore[union-attr] - elif attr_name == "value_int": - tensor = ir.tensor(attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name) - elif attr_name == "value_ints": - tensor = ir.tensor( - attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name - ) - elif attr_name == "value_float": - tensor = ir.tensor( - attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name - ) - elif attr_name == "value_floats": - tensor = ir.tensor( - attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name - ) - elif attr_name in ("value_string", "value_strings"): - tensor = ir.StringTensor( - np.array(attr_value.value, dtype=np.bytes_), name=initializer_name - ) - else: - tensor = None - return tensor # type: ignore[return-value] + if attr_name == "value": + tensor = attr_value.as_tensor() # type: ignore[union-attr] + elif attr_name == "value_int": + tensor = ir.tensor( + 
attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name + ) + elif attr_name == "value_ints": + tensor = ir.tensor( + attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name + ) + elif attr_name == "value_float": + tensor = ir.tensor( + attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name + ) + elif attr_name == "value_floats": + tensor = ir.tensor( + attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name + ) + elif attr_name in ("value_string", "value_strings"): + tensor = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=initializer_name + ) + else: + tensor = None + return tensor # type: ignore[return-value] diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py index 2d1696e7fd..aee6f71e35 100644 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ b/onnxscript/ir/passes/common/constant_manipulation_test.py @@ -14,11 +14,15 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase): @parameterized.parameterized.expand( [ - (ir.DataType.FLOAT,), - (ir.DataType.INT64,), + (ir.DataType.FLOAT, True), + (ir.DataType.FLOAT, False), + (ir.DataType.INT64, True), + (ir.DataType.INT64, False), ] ) - def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtype): + def test_pass_with_lifting_float_and_int_constants_to_initializers( + self, ir_dtype: ir.DataType, lift_all_constants: bool + ): inputs = [ ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))), ir.Value( @@ -51,7 +55,9 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) # Perform lift constants to initializers - result = constant_manipulation.LiftConstantsToInitializersPass()(model) + result = constant_manipulation.LiftConstantsToInitializersPass( + 
lift_all_constants=lift_all_constants + )(model) self.assertTrue(result.modified) # Check that the constant node is lifted to an initializer self.assertEqual(len(result.model.graph.initializers), 1) @@ -67,7 +73,15 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(self, ir_dtyp len([node for node in result.model.graph if node.op_type == "Constant"]), 0 ) - def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): + @parameterized.parameterized.expand( + [ + (True,), + (False,), + ] + ) + def test_pass_with_lifting_constants_to_initializers_within_subgraph( + self, lift_all_constants: bool + ): input_value = ir.Value( name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) ) @@ -115,7 +129,9 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): graph=main_graph, ir_version=10, ) - result = constant_manipulation.LiftConstantsToInitializersPass()(model) + result = constant_manipulation.LiftConstantsToInitializersPass( + lift_all_constants=lift_all_constants + )(model) self.assertTrue(result.modified) # Check that the constant node is lifted to the subgraph initializers for node in ir.traversal.RecursiveGraphIterator(result.model.graph): @@ -136,16 +152,26 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(self): @parameterized.parameterized.expand( [ - (1.0, "value_float", np.float32), - (1, "value_int", np.int64), - ("hello world!", "value_string", np.bytes_), - ([1.0, 2.0, 3.0], "value_floats", np.float32), - ([1, 2, 3], "value_ints", np.int64), - (["hello world!", "thank you."], "value_strings", np.bytes_), + (1.0, "value_float", np.float32, True), + (1.0, "value_float", np.float32, False), + (1, "value_int", np.int64, True), + (1, "value_int", np.int64, False), + ("hello world!", "value_string", np.bytes_, True), + ("hello world!", "value_string", np.bytes_, False), + ([1.0, 2.0, 3.0], "value_floats", np.float32, True), + ([1.0, 2.0, 3.0], "value_floats", 
np.float32, False), + ([1, 2, 3], "value_ints", np.int64, True), + ([1, 2, 3], "value_ints", np.int64, False), + (["hello world!", "thank you."], "value_strings", np.bytes_, True), + (["hello world!", "thank you."], "value_strings", np.bytes_, False), ] ) def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( - self, value, constant_attribute, np_dtype + self, + value: float | int | str | list[float] | list[int] | list[str], + constant_attribute: str, + np_dtype: type[np.dtype], + lift_all_constants: bool, ): input_value = ir.Value( name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) @@ -179,11 +205,47 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) # Perform lift constants to initializers - result = constant_manipulation.LiftConstantsToInitializersPass()(model) - self.assertTrue(result.modified) - # Check that the constant node is lifted to an initializer - self.assertEqual(len(result.model.graph.initializers), 1) - np.testing.assert_array_equal( - result.model.graph.initializers["val_1"].const_value.numpy(), - np.array(constant_value, dtype=np_dtype), + result = constant_manipulation.LiftConstantsToInitializersPass( + lift_all_constants=lift_all_constants + )(model) + if lift_all_constants: + self.assertTrue(result.modified) + # Check that the constant node is lifted to an initializer + self.assertEqual(len(result.model.graph.initializers), 1) + np.testing.assert_array_equal( + result.model.graph.initializers["val_1"].const_value.numpy(), + np.array(constant_value, dtype=np_dtype), + ) + else: + self.assertFalse(result.modified) + # Check that the constant node is not lifted to an initializer + self.assertEqual(len(result.model.graph.initializers), 0) + + def test_not_lifting_constants_to_initializers_when_it_is_output(self): + input_value = ir.Value( + name="input", 
type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) ) + identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) + + constant_value = ir.tensor(np.random.rand(2, 3).astype(np.float32)) + const_node = ir.node( + "Constant", + inputs=[], + attributes={"value": constant_value}, + num_outputs=1, + ) + + model = ir.Model( + graph=ir.Graph( + inputs=[input_value], + outputs=[identity_node_input.outputs[0], const_node.outputs[0]], + nodes=[identity_node_input, const_node], + opset_imports={"": 20}, + ), + ir_version=10, + ) + + result = constant_manipulation.LiftConstantsToInitializersPass()(model) + self.assertFalse(result.modified) + # Check that the constant node is not lifted to an initializer + self.assertEqual(len(result.model.graph.initializers), 0)