diff --git a/src/onnx_ir/passes/common/common_subexpression_elimination.py b/src/onnx_ir/passes/common/common_subexpression_elimination.py
index a3534f42..711a040e 100644
--- a/src/onnx_ir/passes/common/common_subexpression_elimination.py
+++ b/src/onnx_ir/passes/common/common_subexpression_elimination.py
@@ -17,93 +17,117 @@ class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
+    """Eliminate common subexpressions in ONNX graphs.
+
+    Attributes:
+        size_limit: The maximum number of elements a tensor may have to be CSE'd.
+            Tensors with more elements than size_limit are not CSE'd. Defaults to 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
 
     def call(self, model: ir.Model) -> ir.passes.PassResult:
         """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
         graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
+        modified = self._eliminate_common_subexpression(graph)
 
         return ir.passes.PassResult(
             model,
             modified=modified,
         )
-
-def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
-    """Eliminate common subexpression in ONNX graphs."""
-    # node to node identifier, length of outputs, inputs, and attributes
-    existing_node_info_to_the_node: dict[
-        tuple[
-            ir.OperatorIdentifier,
-            int,  # len(outputs)
-            tuple[int, ...],  # input ids
-            tuple[tuple[str, object], ...],  # attributes
-        ],
-        ir.Node,
-    ] = {}
-
-    for node in graph:
-        # Skip control flow ops like Loop and If.
-        control_flow_op: bool = False
-        # Use equality to check if the node is a common subexpression.
-        attributes = {}
-        for k, v in node.attributes.items():
-            # TODO(exporter team): CSE subgraphs.
-            # NOTE: control flow ops like Loop and If won't be CSEd
-            # because attribute: graph won't match.
-            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
-                control_flow_op = True
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpressions in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid CSE'ing weights and biases.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
                 logger.debug("Skipping control flow op %s", node)
-            # The attribute value could be directly taken from the original
-            # protobuf, so we need to make a copy of it.
-            value = v.value
-            if v.type in (
-                ir.AttributeType.INTS,
-                ir.AttributeType.FLOATS,
-                ir.AttributeType.STRINGS,
-            ):
-                # For INT, FLOAT and STRING attributes, we convert them to tuples
-                # to ensure they are hashable.
-                value = tuple(value)
-            attributes[k] = value
-
-        if control_flow_op:
-            # If the node is a control flow op, we skip it.
-            logger.debug("Skipping control flow op %s", node)
-            continue
-
-        if _is_non_deterministic_op(node):
-            # If the node is a non-deterministic op, we skip it.
-            logger.debug("Skipping non-deterministic op %s", node)
-            continue
-
-        node_info = (
-            node.op_identifier(),
-            len(node.outputs),
-            tuple(id(input) for input in node.inputs),
-            tuple(sorted(attributes.items())),
-        )
-        # Check if the node is a common subexpression.
-        if node_info in existing_node_info_to_the_node:
-            # If it is, this node has an existing node with the same
-            # operator, number of outputs, inputs, and attributes.
-            # We replace the node with the existing node.
-            modified = True
-            existing_node = existing_node_info_to_the_node[node_info]
-            _remove_node_and_replace_values(
-                graph,
-                remove_node=node,
-                remove_values=node.outputs,
-                new_values=existing_node.outputs,
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
             )
-            logger.debug("Reusing node %s", existing_node)
-        else:
-            # If it is not, add to the mapping.
-            existing_node_info_to_the_node[node_info] = node
-    return modified
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
 
 
 def _remove_node_and_replace_values(
diff --git a/src/onnx_ir/passes/common/common_subexpression_elimination_test.py b/src/onnx_ir/passes/common/common_subexpression_elimination_test.py
index b993d463..1c1beafa 100644
--- a/src/onnx_ir/passes/common/common_subexpression_elimination_test.py
+++ b/src/onnx_ir/passes/common/common_subexpression_elimination_test.py
@@ -14,7 +14,13 @@ class TestCommonSubexpressionEliminationPass(unittest.TestCase):
-    def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list[int]):
+    def check_graph(
+        self,
+        model: ir.Model,
+        inputs: list[ir.Value],
+        delta_nodes: list[int],
+        size_limit: int = 10,
+    ):
         """Check if the model applied the CSE pass correctly.
 
         Args:
@@ -23,6 +29,8 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
             delta_nodes: The expected change in the number of nodes in the model.
                 The length of this list should match the number of graphs in the model.
                 (to support subgraphs in the future)
+            size_limit: The maximum number of elements a tensor may have to be CSE'd.
+                Tensors with more elements than size_limit are not CSE'd.
 
         Raises:
             AssertionError: If the model does not match the expected number of nodes or outputs.
@@ -42,7 +50,9 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
 
         original_model_session = ort.InferenceSession(model_proto.SerializeToString())
         original_model_results = original_model_session.run(None, ort_inputs)
 
-        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
+        result = common_subexpression_elimination.CommonSubexpressionEliminationPass(
+            size_limit=size_limit
+        )(model)
 
         result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
         # Check if the number of nodes in the model is correct
@@ -204,7 +214,7 @@ def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
         model_proto = test_model.to_model_proto()
         model = ir.serde.deserialize_model(model_proto)
         self.check_graph(
-            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0]
+            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0, 0, 0]
         )
 
     def test_the_nodes_following_control_flow_ops_are_csed(self):
@@ -321,3 +331,43 @@ def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
         model_proto = test_model.to_model_proto()
         model = ir.serde.deserialize_model(model_proto)
         self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])
+
+    def test_the_constant_nodes_have_the_same_tensor_are_csed(self):
+        """Test if the constant nodes with the same tensor are CSEd.
+
+        def f(x):
+            a = x + [1, 2]
+            b = x + [1, 2]
+            return a + b
+        """
+
+        @script()
+        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
+            a = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
+            b = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
+            return a + b
+
+        model_proto = test_model.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        # Add and Constant nodes should be CSEd
+        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])
+
+    def test_the_constant_nodes_with_the_tensors_larger_than_size_limit_are_not_csed(self):
+        """Test if the constant nodes with the tensors larger than size limit are not CSEd.
+
+        def f(x):
+            a = x + [1, 2, 3, 4]
+            b = x + [1, 2, 3, 4]
+            return a + b
+        """
+
+        @script()
+        def test_model(x: FLOAT[4, 4]) -> FLOAT[4, 4]:
+            a = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
+            b = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
+            return a + b
+
+        model_proto = test_model.to_model_proto()
+        model = ir.serde.deserialize_model(model_proto)
+        # Add and Constant nodes should not be CSEd
+        self.check_graph(model, [np.random.rand(4, 4)], delta_nodes=[0], size_limit=3)
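
For context, a minimal usage sketch of the new size_limit option. It assumes the import paths implied by the file locations above (onnx_ir and onnx_ir.passes.common.common_subexpression_elimination); ir.serde.deserialize_model and the size_limit keyword come from this diff, while the model path and the threshold value of 16 are hypothetical placeholders.

import onnx
import onnx_ir as ir
from onnx_ir.passes.common import common_subexpression_elimination

# Hypothetical input file; any ONNX model works here.
model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)

# Constant tensors with more than 16 elements are left alone, so weights and
# biases are not deduplicated; smaller identical constants are CSE'd.
result = common_subexpression_elimination.CommonSubexpressionEliminationPass(
    size_limit=16
)(model)
print("CSE modified the graph:", result.modified)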