Skip to content

Commit 62e750c

Browse files
[Pass] Support CSE constant nodes (#92)
Constant nodes whose tensor has no more elements than the size limit are now CSE'd. A TENSOR attribute is hashed via tobytes when its size does not exceed size_limit (a parameter of the pass). --------- Signed-off-by: Ti-Tai Wang <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent d8fa011 commit 62e750c

File tree

2 files changed

+152
-78
lines changed

2 files changed

+152
-78
lines changed

src/onnx_ir/passes/common/common_subexpression_elimination.py

Lines changed: 99 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,117 @@
1717

1818

1919
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
20-
"""Eliminate common subexpression in ONNX graphs."""
20+
"""Eliminate common subexpression in ONNX graphs.
21+
22+
Attributes:
23+
size_limit: The maximum number of elements a tensor may have to be CSE'd. If the tensor
24+
contains more elements than size_limit, it will not be CSE'd. Default is 10.
25+
26+
"""
27+
28+
def __init__(self, size_limit: int = 10):
29+
"""Initialize the CommonSubexpressionEliminationPass."""
30+
super().__init__()
31+
self.size_limit = size_limit
2132

2233
def call(self, model: ir.Model) -> ir.passes.PassResult:
2334
"""Return the same ir.Model but with CSE applied to the graph."""
24-
modified = False
2535
graph = model.graph
26-
27-
modified = _eliminate_common_subexpression(graph, modified)
36+
modified = self._eliminate_common_subexpression(graph)
2837

2938
return ir.passes.PassResult(
3039
model,
3140
modified=modified,
3241
)
3342

34-
35-
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
36-
"""Eliminate common subexpression in ONNX graphs."""
37-
# node to node identifier, length of outputs, inputs, and attributes
38-
existing_node_info_to_the_node: dict[
39-
tuple[
40-
ir.OperatorIdentifier,
41-
int, # len(outputs)
42-
tuple[int, ...], # input ids
43-
tuple[tuple[str, object], ...], # attributes
44-
],
45-
ir.Node,
46-
] = {}
47-
48-
for node in graph:
49-
# Skip control flow ops like Loop and If.
50-
control_flow_op: bool = False
51-
# Use equality to check if the node is a common subexpression.
52-
attributes = {}
53-
for k, v in node.attributes.items():
54-
# TODO(exporter team): CSE subgraphs.
55-
# NOTE: control flow ops like Loop and If won't be CSEd
56-
# because attribute: graph won't match.
57-
if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
58-
control_flow_op = True
43+
def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
44+
"""Eliminate common subexpression in ONNX graphs."""
45+
modified: bool = False
46+
# node to node identifier, length of outputs, inputs, and attributes
47+
existing_node_info_to_the_node: dict[
48+
tuple[
49+
ir.OperatorIdentifier,
50+
int, # len(outputs)
51+
tuple[int, ...], # input ids
52+
tuple[tuple[str, object], ...], # attributes
53+
],
54+
ir.Node,
55+
] = {}
56+
57+
for node in graph:
58+
# Skip control flow ops like Loop and If.
59+
control_flow_op: bool = False
60+
# Skip large tensors to avoid cse weights and bias.
61+
large_tensor: bool = False
62+
# Use equality to check if the node is a common subexpression.
63+
attributes = {}
64+
for k, v in node.attributes.items():
65+
# TODO(exporter team): CSE subgraphs.
66+
# NOTE: control flow ops like Loop and If won't be CSEd
67+
# because attribute: graph won't match.
68+
if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
69+
control_flow_op = True
70+
break
71+
# The attribute value could be directly taken from the original
72+
# protobuf, so we need to make a copy of it.
73+
value = v.value
74+
if v.type in (
75+
ir.AttributeType.INTS,
76+
ir.AttributeType.FLOATS,
77+
ir.AttributeType.STRINGS,
78+
):
79+
# For INT, FLOAT and STRING attributes, we convert them to tuples
80+
# to ensure they are hashable.
81+
value = tuple(value)
82+
elif v.type is ir.AttributeType.TENSOR:
83+
if value.size > self.size_limit:
84+
# If the tensor is larger than the size limit, we skip it.
85+
large_tensor = True
86+
break
87+
np_value = value.numpy()
88+
89+
value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
90+
attributes[k] = value
91+
92+
if control_flow_op:
93+
# If the node is a control flow op, we skip it.
5994
logger.debug("Skipping control flow op %s", node)
60-
# The attribute value could be directly taken from the original
61-
# protobuf, so we need to make a copy of it.
62-
value = v.value
63-
if v.type in (
64-
ir.AttributeType.INTS,
65-
ir.AttributeType.FLOATS,
66-
ir.AttributeType.STRINGS,
67-
):
68-
# For INT, FLOAT and STRING attributes, we convert them to tuples
69-
# to ensure they are hashable.
70-
value = tuple(value)
71-
attributes[k] = value
72-
73-
if control_flow_op:
74-
# If the node is a control flow op, we skip it.
75-
logger.debug("Skipping control flow op %s", node)
76-
continue
77-
78-
if _is_non_deterministic_op(node):
79-
# If the node is a non-deterministic op, we skip it.
80-
logger.debug("Skipping non-deterministic op %s", node)
81-
continue
82-
83-
node_info = (
84-
node.op_identifier(),
85-
len(node.outputs),
86-
tuple(id(input) for input in node.inputs),
87-
tuple(sorted(attributes.items())),
88-
)
89-
# Check if the node is a common subexpression.
90-
if node_info in existing_node_info_to_the_node:
91-
# If it is, this node has an existing node with the same
92-
# operator, number of outputs, inputs, and attributes.
93-
# We replace the node with the existing node.
94-
modified = True
95-
existing_node = existing_node_info_to_the_node[node_info]
96-
_remove_node_and_replace_values(
97-
graph,
98-
remove_node=node,
99-
remove_values=node.outputs,
100-
new_values=existing_node.outputs,
95+
continue
96+
97+
if large_tensor:
98+
# If the node has a large tensor, we skip it.
99+
logger.debug("Skipping large tensor in node %s", node)
100+
continue
101+
102+
if _is_non_deterministic_op(node):
103+
# If the node is a non-deterministic op, we skip it.
104+
logger.debug("Skipping non-deterministic op %s", node)
105+
continue
106+
107+
node_info = (
108+
node.op_identifier(),
109+
len(node.outputs),
110+
tuple(id(input) for input in node.inputs),
111+
tuple(sorted(attributes.items())),
101112
)
102-
logger.debug("Reusing node %s", existing_node)
103-
else:
104-
# If it is not, add to the mapping.
105-
existing_node_info_to_the_node[node_info] = node
106-
return modified
113+
# Check if the node is a common subexpression.
114+
if node_info in existing_node_info_to_the_node:
115+
# If it is, this node has an existing node with the same
116+
# operator, number of outputs, inputs, and attributes.
117+
# We replace the node with the existing node.
118+
modified = True
119+
existing_node = existing_node_info_to_the_node[node_info]
120+
_remove_node_and_replace_values(
121+
graph,
122+
remove_node=node,
123+
remove_values=node.outputs,
124+
new_values=existing_node.outputs,
125+
)
126+
logger.debug("Reusing node %s", existing_node)
127+
else:
128+
# If it is not, add to the mapping.
129+
existing_node_info_to_the_node[node_info] = node
130+
return modified
107131

108132

109133
def _remove_node_and_replace_values(

src/onnx_ir/passes/common/common_subexpression_elimination_test.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515

1616
class TestCommonSubexpressionEliminationPass(unittest.TestCase):
17-
def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list[int]):
17+
def check_graph(
18+
self,
19+
model: ir.Model,
20+
inputs: list[ir.Value],
21+
delta_nodes: list[int],
22+
size_limit: int = 10,
23+
):
1824
"""Check if the model applied the CSE pass correctly.
1925
2026
Args:
@@ -23,6 +29,8 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
2329
delta_nodes: The expected change in the number of nodes in the model.
2430
The length of this list should match the number of graphs
2531
in the model. (to support subgraphs in the future)
32+
size_limit: The maximum number of elements a tensor may have to be CSE'd. If the tensor
33+
contains more elements than size_limit, it will not be CSE'd.
2634
2735
Raises:
2836
AssertionError: If the model does not match the expected number of nodes or outputs.
@@ -42,7 +50,9 @@ def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list
4250
original_model_session = ort.InferenceSession(model_proto.SerializeToString())
4351
original_model_results = original_model_session.run(None, ort_inputs)
4452

45-
result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
53+
result = common_subexpression_elimination.CommonSubexpressionEliminationPass(
54+
size_limit=size_limit
55+
)(model)
4656

4757
result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
4858
# Check if the number of nodes in the model is correct
@@ -204,7 +214,7 @@ def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
204214
model_proto = test_model.to_model_proto()
205215
model = ir.serde.deserialize_model(model_proto)
206216
self.check_graph(
207-
model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0]
217+
model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0, 0, 0]
208218
)
209219

210220
def test_the_nodes_following_control_flow_ops_are_csed(self):
@@ -321,3 +331,43 @@ def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
321331
model_proto = test_model.to_model_proto()
322332
model = ir.serde.deserialize_model(model_proto)
323333
self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])
334+
335+
def test_the_constant_nodes_have_the_same_tensor_are_csed(self):
336+
"""Test if the constant nodes with the same tensor are CSEd.
337+
338+
def f(x):
339+
a = x + [1, 2]
340+
b = x + [1, 2]
341+
return a + b
342+
"""
343+
344+
@script()
345+
def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
346+
a = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
347+
b = op.Add(x, op.Constant(value=np.array([1.0, 2.0], dtype=np.float32)))
348+
return a + b
349+
350+
model_proto = test_model.to_model_proto()
351+
model = ir.serde.deserialize_model(model_proto)
352+
# Add and Constant nodes should be CSEd
353+
self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])
354+
355+
def test_the_constant_nodes_with_the_tensors_larger_than_size_limit_are_not_csed(self):
356+
"""Test if the constant nodes with the tensors larger than size limit are not CSEd.
357+
358+
def f(x):
359+
a = x + [1, 2, 3, 4]
360+
b = x + [1, 2, 3, 4]
361+
return a + b
362+
"""
363+
364+
@script()
365+
def test_model(x: FLOAT[4, 4]) -> FLOAT[4, 4]:
366+
a = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
367+
b = op.Add(x, op.Constant(value=np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)))
368+
return a + b
369+
370+
model_proto = test_model.to_model_proto()
371+
model = ir.serde.deserialize_model(model_proto)
372+
# Add and Constant nodes should not be CSEd
373+
self.check_graph(model, [np.random.rand(4, 4)], delta_nodes=[0], size_limit=3)

0 commit comments

Comments
 (0)