Skip to content

Support common subexpression elimination pass (CSE) #2304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still faling
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2796,7 +2796,7 @@ def __init__(
model_version: int | None = None,
doc_string: str | None = None,
functions: Sequence[Function] = (),
meta_data_props: dict[str, str] | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
self.graph: Graph = graph
self.ir_version = ir_version
Expand All @@ -2807,7 +2807,7 @@ def __init__(
self.doc_string = doc_string
self._functions = {func.identifier(): func for func in functions}
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = meta_data_props
self._metadata_props: dict[str, str] | None = metadata_props

@property
def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"AddInitializersToInputsPass",
"CheckerPass",
"ClearMetadataAndDocStringPass",
"CommonSubexpressionEliminationPass",
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
Expand All @@ -19,6 +20,9 @@
from onnxscript.ir.passes.common.clear_metadata_and_docstring import (
ClearMetadataAndDocStringPass,
)
from onnxscript.ir.passes.common.common_subexpression_elimination import (
CommonSubexpressionEliminationPass,
)
from onnxscript.ir.passes.common.constant_manipulation import (
AddInitializersToInputsPass,
LiftConstantsToInitializersPass,
Expand Down
138 changes: 138 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
"CommonSubexpressionEliminationPass",
]

import logging
from typing import Sequence

from onnxscript import ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that removes duplicated computations from a model.

    Only ``model.graph`` is handed to the elimination routine; the model
    object itself is returned unchanged in identity.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run CSE on ``model.graph`` and report whether anything changed.

        Args:
            model: The model whose main graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same ``model`` object, with
            ``modified`` set to the flag returned by the elimination routine.
        """
        # Start from "not modified"; the helper flips it on the first removal.
        changed = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=changed)


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Remove nodes of ``graph`` that duplicate an earlier node.

    Two nodes are treated as the same subexpression when they agree on
    operator identifier, number of outputs, the identity (``id``) of every
    input value, and their attributes. The first such node encountered is
    kept; each later match is removed and its outputs are redirected to the
    kept node's outputs.

    Nodes carrying a GRAPH or GRAPHS attribute are never considered.

    Args:
        graph: The graph to rewrite in place.
        modified: Incoming "already modified" flag, returned as-is when no
            duplicate is found.

    Returns:
        ``True`` if at least one node was eliminated, otherwise ``modified``.
    """
    # Attribute kinds that mark a node as owning subgraphs.
    # TODO(exporter team): CSE subgraphs.
    subgraph_attr_types = (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS)
    # Attribute kinds whose values are turned into tuples so they can be hashed.
    sequence_attr_types = (
        ir.AttributeType.INTS,
        ir.AttributeType.FLOATS,
        ir.AttributeType.STRINGS,
    )

    # Signature -> first node seen with that signature. The signature is
    # (operator identifier, len(outputs), input ids, sorted attributes).
    first_node_by_signature: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        owns_subgraph = False
        hashable_attrs = {}
        for attr_name, attr in node.attributes.items():
            # NOTE: control flow ops like Loop and If won't be CSEd
            # because their graph attributes won't match.
            if attr.type in subgraph_attr_types:
                owns_subgraph = True
                logger.debug("Skipping control flow op %s", node)
            # The value may come straight from the original protobuf, so the
            # sequence kinds are copied into a fresh (hashable) tuple.
            attr_value = attr.value
            if attr.type in sequence_attr_types:
                attr_value = tuple(attr_value)
            hashable_attrs[attr_name] = attr_value

        if owns_subgraph:
            # Control flow op: leave it alone and do not register it.
            continue

        signature = (
            node.op_identifier(),
            len(node.outputs),
            tuple(id(input_value) for input_value in node.inputs),
            tuple(sorted(hashable_attrs.items())),
        )

        if signature not in first_node_by_signature:
            # First time this expression shows up: remember it.
            first_node_by_signature[signature] = node
            continue

        # An equivalent node already exists: reuse its outputs and drop
        # the current node.
        modified = True
        kept_node = first_node_by_signature[signature]
        _remove_node_and_replace__values(
            graph,
            remove_nodes=[node],
            remove_values=node.outputs,
            new_values=kept_node.outputs,
        )
        logger.debug("Reusing node %s", kept_node)

    return modified


def _remove_node_and_replace__values(
    graph: ir.Graph,
    /,
    remove_nodes: Sequence[ir.Node],
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    ``remove_values`` and ``new_values`` are paired positionally. For each
    pair the replacement value takes over the old value's type, shape,
    constant value and name, then every use of the old value (including
    graph outputs) is pointed at the replacement, and finally
    ``remove_nodes`` are removed from ``graph``.

    Args:
        graph: The graph to replace nodes and values in.
        remove_nodes: The nodes to remove. The visible caller passes a
            list, which is forwarded unchanged to ``graph.remove``.
        remove_values: The values to replace.
        new_values: The values to replace with.
    """

    for old_value, new_value in zip(remove_values, new_values):
        # Propagate relevant info from old value to new value
        # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
        # merging old and new type/shape info.
        new_value.type = old_value.type
        new_value.shape = old_value.shape
        new_value.const_value = old_value.const_value
        # NOTE(review): this overwrites the surviving value's own name with the
        # removed value's name — confirm that is intended when the surviving
        # value is itself externally visible.
        new_value.name = old_value.name

    # Reconnect the users of the deleted values to use the new values
    ir.convenience.replace_all_uses_with(remove_values, new_values)
    # Update graph/function outputs if the node generates output
    replacement_mapping = dict(zip(remove_values, new_values))
    for idx, graph_or_function_output in enumerate(graph.outputs):
        if graph_or_function_output in replacement_mapping:
            # NOTE(review): a Codecov annotation on the originating change
            # reports this assignment as not covered by tests.
            graph.outputs[idx] = replacement_mapping[graph_or_function_output]

    graph.remove(remove_nodes, safe=True)
Loading
Loading