Skip to content

[Pass] Support CSE constant nodes #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 99 additions & 75 deletions src/onnx_ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,93 +17,117 @@


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """Eliminate common subexpressions in ONNX graphs.

    Two nodes are considered duplicates when they share the same operator
    identifier, number of outputs, input values (by identity), and attribute
    values. The later duplicate is removed and its outputs are rewired to the
    outputs of the first occurrence.

    Attributes:
        size_limit: The maximum number of elements a tensor attribute may
            contain and still be considered for CSE. Tensors with more
            elements than ``size_limit`` are skipped, to avoid deduplicating
            weights and biases. Default is 10.
    """

    def __init__(self, size_limit: int = 10):
        """Initialize the CommonSubexpressionEliminationPass.

        Args:
            size_limit: Maximum tensor-attribute element count eligible for CSE.
        """
        super().__init__()
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Return the same ir.Model but with CSE applied to its main graph."""
        modified = self._eliminate_common_subexpression(model.graph)
        return ir.passes.PassResult(
            model,
            modified=modified,
        )

    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
        """Eliminate common subexpressions in ``graph`` in place.

        Returns:
            True if at least one duplicate node was removed, False otherwise.
        """
        modified: bool = False
        # Node signature -> first node seen with that signature. The signature
        # is (operator identifier, number of outputs, input ids, attributes).
        existing_node_info_to_the_node: dict[
            tuple[
                ir.OperatorIdentifier,
                int,  # len(outputs)
                tuple[int, ...],  # input ids
                tuple[tuple[str, object], ...],  # attributes
            ],
            ir.Node,
        ] = {}

        for node in graph:
            # Skip control flow ops like Loop and If.
            control_flow_op: bool = False
            # Skip large tensors to avoid CSE-ing weights and biases.
            large_tensor: bool = False
            # Use equality to check if the node is a common subexpression.
            attributes = {}
            for k, v in node.attributes.items():
                # TODO(exporter team): CSE subgraphs.
                # NOTE: control flow ops like Loop and If won't be CSEd
                # because attribute: graph won't match.
                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                    control_flow_op = True
                    break
                # The attribute value could be directly taken from the original
                # protobuf, so we need to make a copy of it.
                value = v.value
                if v.type in (
                    ir.AttributeType.INTS,
                    ir.AttributeType.FLOATS,
                    ir.AttributeType.STRINGS,
                ):
                    # For INTS, FLOATS and STRINGS attributes, we convert them
                    # to tuples to ensure they are hashable.
                    value = tuple(value)
                elif v.type is ir.AttributeType.TENSOR:
                    if value.size > self.size_limit:
                        # If the tensor is larger than the size limit, we skip it.
                        large_tensor = True
                        break
                    np_value = value.numpy()
                    # Hashable fingerprint of the tensor: shape, dtype and raw bytes.
                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
                attributes[k] = value

            if control_flow_op:
                # If the node is a control flow op, we skip it.
                logger.debug("Skipping control flow op %s", node)
                continue

            if large_tensor:
                # If the node has a large tensor, we skip it.
                logger.debug("Skipping large tensor in node %s", node)
                continue

            if _is_non_deterministic_op(node):
                # If the node is a non-deterministic op, we skip it.
                logger.debug("Skipping non-deterministic op %s", node)
                continue

            node_info = (
                node.op_identifier(),
                len(node.outputs),
                tuple(id(input) for input in node.inputs),
                tuple(sorted(attributes.items())),
            )
            # Check if the node is a common subexpression.
            if node_info in existing_node_info_to_the_node:
                # If it is, this node has an existing node with the same
                # operator, number of outputs, inputs, and attributes.
                # We replace the node with the existing node.
                modified = True
                existing_node = existing_node_info_to_the_node[node_info]
                _remove_node_and_replace_values(
                    graph,
                    remove_node=node,
                    remove_values=node.outputs,
                    new_values=existing_node.outputs,
                )
                logger.debug("Reusing node %s", existing_node)
            else:
                # If it is not, add to the mapping.
                existing_node_info_to_the_node[node_info] = node
        return modified


def _remove_node_and_replace_values(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@


class TestCommonSubexpressionEliminationPass(unittest.TestCase):
def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list[int]):
def check_graph(
self,
model: ir.Model,
inputs: list[ir.Value],
delta_nodes: list[int],
size_limit: int = 10,
):
"""Check if the model applied the CSE pass correctly.

Args:
Expand All @@ -23,6 +29,8 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
delta_nodes: The expected change in the number of nodes in the model.
The length of this list should match the number of graphs
in the model. (to support subgraphs in the future)
size_limit: The maximum size of a tensor eligible for CSE. If the tensor
contains more elements than size_limit, it will not be CSE'd.

Raises:
AssertionError: If the model does not match the expected number of nodes or outputs.
Expand All @@ -42,7 +50,9 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
original_model_session = ort.InferenceSession(model_proto.SerializeToString())
original_model_results = original_model_session.run(None, ort_inputs)

result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
result = common_subexpression_elimination.CommonSubexpressionEliminationPass(
size_limit=size_limit
)(model)

result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
# Check if the number of nodes in the model is correct
Expand Down Expand Up @@ -204,7 +214,7 @@ def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self.check_graph(
model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0]
model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0, 0, 0]
)

def test_the_nodes_following_control_flow_ops_are_csed(self):
Expand Down Expand Up @@ -321,3 +331,43 @@ def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])

    def test_the_constant_nodes_have_the_same_tensor_are_csed(self):
        """Test if the constant nodes with the same tensor are CSEd.

        def f(x):
            a = x + [1, 2]
            b = x + [1, 2]
            return a + b
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            # Two byte-identical Constant nodes feed two identical Add nodes,
            # so one Constant and one Add are redundant. The 2-element tensor
            # is within the default size_limit of 10, so it is CSE-eligible.
            a = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
            b = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
            return a + b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # Add and Constant nodes should be CSEd
        # (delta_nodes=[2]: one duplicate Constant + one duplicate Add removed).
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])

    def test_the_constant_nodes_with_the_tensors_larger_than_size_limit_are_not_csed(self):
        """Test if the constant nodes with the tensors larger than size limit are not CSEd.

        def f(x):
            a = x + [1, 2, 3, 4]
            b = x + [1, 2, 3, 4]
            return a + b
        """

        @script()
        def test_model(x: FLOAT[4, 4]) -> FLOAT[4, 4]:
            # The duplicated Constant tensors hold 4 elements, which exceeds
            # the size_limit of 3 passed below, so they must be skipped by
            # the pass; the Adds then keep distinct inputs and are not CSEd.
            a = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
            b = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
            return a + b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # Add and Constant nodes should not be CSEd (delta_nodes=[0]: no change).
        self.check_graph(model, [np.random.rand(4, 4)], delta_nodes=[0], size_limit=3)