Support common subexpression elimination pass (CSE) #2304

Merged 23 commits on May 30, 2025
Changes from 2 commits

Commits (23)
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still failing
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
4 changes: 2 additions & 2 deletions onnxscript/ir/_core.py
@@ -2734,7 +2734,7 @@ def __init__(
model_version: int | None = None,
doc_string: str | None = None,
functions: Sequence[Function] = (),
meta_data_props: dict[str, str] | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
self.graph: Graph = graph
self.ir_version = ir_version
@@ -2745,7 +2745,7 @@ def __init__(
self.doc_string = doc_string
self._functions = {func.identifier(): func for func in functions}
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = meta_data_props
self._metadata_props: dict[str, str] | None = metadata_props

@property
def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
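The rename above makes the ir.Model constructor keyword match the metadata_props spelling used by the rest of the IR. A minimal sketch of constructing a model with the corrected keyword (the graph contents and field values below are placeholders, not taken from this PR):

from onnxscript import ir

# Placeholder empty graph; real callers pass their deserialized graph.
graph = ir.Graph(inputs=[], outputs=[], nodes=[], name="example")
model = ir.Model(
    graph,
    ir_version=10,
    metadata_props={"source": "example"},  # was meta_data_props before this change
)
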
163 changes: 163 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
@@ -0,0 +1,163 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
"CommonSubexpressionEliminationPass",
]

import logging

from onnxscript import ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.FunctionalPass):
"""Eliminate common subexpression in ONNX graphs."""

def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Return the same ir.Model but with CSE applied to the graph."""
modified = False
# 1. Initialize a new graph. It will be used to store the new nodes
# and to replace the old graph.
old_graph = model.graph
# Values and nodes point to the old graph, so they need to be
# created in the new graph.
new_graph = ir.Graph(
inputs=[],
outputs=[],
nodes=[],
initializers=[],
name=old_graph.name,
metadata_props=old_graph.metadata_props,
opset_imports=old_graph.opset_imports,
doc_string=old_graph.doc_string,
)
# 2. Create mappings from old nodes and values to their new counterparts.
old_node_hash_to_new_node: dict[int, ir.Node] = {}
old_value_to_new_value: dict[ir.Value, ir.Value] = {}

# 3. Create inputs and initializers in the new graph
# from the old graph.
for input in old_graph.inputs:
new_input = _copy_value(input)
new_graph.inputs.append(new_input)
old_value_to_new_value[input] = new_input
logger.debug("Adding input %s", new_input.name)
for initializer in old_graph.initializers.values():
new_initializer = _copy_value(initializer)
new_graph.register_initializer(new_initializer)
old_value_to_new_value[initializer] = new_initializer
logger.debug("Adding initializer %s", new_initializer.name)

# 4. Create nodes in the new graph from the old graph.
for old_node in old_graph:
# 4.0. Iterate and update node inputs.
old_node_inputs: list[ir.Value] = []
for old_input in old_node.inputs:
assert old_input is not None
if old_input in old_value_to_new_value:
old_node_inputs.append(old_value_to_new_value[old_input])
else:
old_node_inputs.append(old_input)
# 4.1. Construct the (node, inputs, attributes) hash to
# check if the node is a common subexpression.
# ir.Attr is not hashable, so hash the underlying attribute values instead.
attributes = {}
for k, v in old_node.attributes.items():
    assert isinstance(v, ir.Attr)
    value = v.value
    # List-valued attributes (e.g. INTS/FLOATS) are unhashable; use tuples.
    if isinstance(value, list):
        value = tuple(value)
    attributes[k] = value
hash_value = hash(
    (
        old_node.op_identifier(),
        tuple(old_node_inputs),
        # Hash attribute values (not just names) so that nodes differing
        # only in attribute values are not merged.
        tuple(sorted(attributes.items())),
    )
)
# TODO(titaiwang): Subgraphs are not supported yet.
# TODO(titaiwang): Skip control flow nodes?
# 4.2. Check if the node is a common subexpression.
if hash_value in old_node_hash_to_new_node:
# 4.2.1. If it is, this node is already in the new graph, so
# we don't need to create a new node.
modified = True
new_node = old_node_hash_to_new_node[hash_value]
logger.debug("Reusing node %s", new_node.name)
else:
# 4.2.2. If it is not, create a new node and add it to the graph.
new_node = _copy_node(old_node, old_value_to_new_value)
new_graph.append(new_node)
old_node_hash_to_new_node[hash_value] = new_node
# 4.3 Add the node outputs to the mapping.
old_value_to_new_value.update(dict(zip(old_node.outputs, new_node.outputs)))
# 5. Create outputs in the new graph from the old graph.
for output in old_graph.outputs:
new_output = old_value_to_new_value[output]
new_graph.outputs.append(new_output)
logger.debug("Adding output %s", new_output.name)
# 6. Replace the old graph with the new graph.
model = _copy_model(original_model=model, new_graph=new_graph)
return ir.passes.PassResult(
model,
modified=modified,
)


def _copy_value(original_value: ir.Value) -> ir.Value:
"""Copy an IR value."""
new_input = ir.Value(
name=original_value.name,
shape=original_value.shape,
type=original_value.type,
doc_string=original_value.doc_string,
const_value=original_value.const_value,
)
return new_input


def _copy_node(
original_node: ir.Node, old_value_to_new_value: dict[ir.Value, ir.Value]
) -> ir.Node:
"""Copy an IR node."""
new_inputs: list[ir.Value] = []
for original_input in original_node.inputs:
if original_input in old_value_to_new_value:
new_inputs.append(old_value_to_new_value[original_input])
else:
raise ValueError(f"Input {original_input} not found in old_value_to_new_value")
new_node = ir.node(
domain=original_node.domain,
op_type=original_node.op_type,
inputs=new_inputs,
attributes=original_node.attributes,
overload=original_node.overload,
num_outputs=len(original_node.outputs),
metadata_props=original_node.metadata_props,
doc_string=original_node.doc_string,
name=original_node.name,
version=original_node.version,
)
return new_node


def _copy_model(
original_model: ir.Model,
new_graph: ir.Graph,
) -> ir.Model:
"""Copy an IR model but with the new graph."""
functions = tuple(original_model.functions.values())
new_model = ir.Model(
graph=new_graph,
ir_version=original_model.ir_version,
producer_name=original_model.producer_name,
producer_version=original_model.producer_version,
domain=original_model.domain,
model_version=original_model.model_version,
doc_string=original_model.doc_string,
functions=functions,
metadata_props=original_model.metadata_props,
)
return new_model
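For reviewers who want to try the new pass on a model, a minimal usage sketch (the file names are placeholders; the pass call and serde APIs are the ones exercised in the tests below):

import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import common_subexpression_elimination

model = ir.serde.deserialize_model(onnx.load("model.onnx"))  # placeholder path
result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
print("CSE modified the graph:", result.modified)
onnx.save(ir.serde.serialize_model(result.model), "model_cse.onnx")
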
115 changes: 115 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination_test.py
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnxruntime as ort

from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op
from onnxscript.ir.passes.common import common_subexpression_elimination


class TestCommonSubexpressionEliminationPass(unittest.TestCase):
def check_graph(self, model: ir.Model, inputs: list[np.ndarray], delta_nodes: int = 0):
"""Check if the model applied the CSE pass correctly."""
result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)
# Check if the number of nodes in the model is correct
self.assertEqual(len(model.graph), len(result.model.graph) + delta_nodes)
self.assertEqual(result.modified, len(model.graph) > len(result.model.graph))

model_proto = ir.serde.serialize_model(model)
result_proto = ir.serde.serialize_model(result.model)
# Check if the models produce the same output
# with the same inputs
ort_inputs = {
k.name: np.random.rand(*v.shape).astype(np.float32)
for k, v in zip(model.graph.inputs, inputs)
}
ort_session = ort.InferenceSession(model_proto.SerializeToString())
ort_results = ort_session.run(None, ort_inputs)
result_session = ort.InferenceSession(result_proto.SerializeToString())
result_results = result_session.run(None, ort_inputs)
for idx, ort_result in enumerate(ort_results):
np.testing.assert_allclose(ort_result, result_results[idx], rtol=1e-5, atol=1e-5)

def test_two_branches_with_the_same_operations_is_csed(self):
"""Test if two branches with the same operations are CSEd.

def test_simple(self):
def f(x):
a = x.cos()
b = x.cos()
c = a + a
d = b + b
return c + d

x = torch.randn(2, 2)
"""

@script()
def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
a = op.Cos(x)
b = op.Cos(x)
c = a + a
d = b + b
return c + d

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)

self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=2)

def test_more_operations_in_two_branches_with_the_same_operations_is_csed(self):
"""Test if two branches with the same operations are CSEd.

def test_simple(self):
def f(x):
a = x.cos().sin()
b = x.cos().sin()
c = a + a
d = b + b
return c + d

x = torch.randn(1)
"""

@script()
def test_model(x: FLOAT[1]) -> FLOAT[1]:
a = op.Sin(op.Cos(x))
b = op.Sin(op.Cos(x))
c = a + a
d = b + b
return c + d

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self.check_graph(model, [np.random.rand(1)], delta_nodes=3)

def test_multiple_same_ops_with_attributes_are_csed(self):
"""Test if multiple same ops are CSEd.

def f(x):
a = x.sum()
b = x.sum()
c = x.sum()
d = x.sum()
return a + b + c + d

x = torch.randn(2, 2)

"""

@script()
def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
a = op.ReduceSum(x, keepdims=False)
b = op.ReduceSum(x, keepdims=False)
c = op.ReduceSum(x, keepdims=False)
d = op.ReduceSum(x, keepdims=False)
return a + b + c + d

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)
self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=3)
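A possible follow-up test (not included in this diff) would cover the negative case, i.e. that identical ops over different inputs are left alone; a sketch that reuses check_graph:

def test_same_op_on_different_inputs_is_not_csed(self):
    """Sketch: Cos over two different inputs must not be merged."""

    @script()
    def test_model(x: FLOAT[2, 2], y: FLOAT[2, 2]) -> FLOAT[2, 2]:
        a = op.Cos(x)
        b = op.Cos(y)
        return a + b

    model_proto = test_model.to_model_proto()
    model = ir.serde.deserialize_model(model_proto)
    self.check_graph(
        model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=0
    )
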
2 changes: 1 addition & 1 deletion onnxscript/ir/serde.py
@@ -501,7 +501,7 @@ def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
model_version=_get_field(proto, "model_version"),
doc_string=_get_field(proto, "doc_string"),
functions=functions,
meta_data_props=deserialize_metadata_props(proto.metadata_props),
metadata_props=deserialize_metadata_props(proto.metadata_props),
)

# Handle experimental value info for functions created by the dynamo exporter in IR version 9