-
Notifications
You must be signed in to change notification settings - Fork 72
Support common subexpression elimination pass (CSE) #2304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
f6563ad
draft tests
titaiwangms 2f55000
add more tests
titaiwangms d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms 017ef27
inplace
titaiwangms 2a370e4
add recursive function but one test is still faling
titaiwangms d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms 706b86a
revert subgraph cse support
titaiwangms dcbc08d
add another test for subgraph
titaiwangms 55d32c7
add the pass to optimization
titaiwangms c5cab5b
make repeated contained attributes hashable
titaiwangms be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms da05efb
delete previous_node and only delete the node
titaiwangms ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms 1d4fd53
create and use a stateless function
titaiwangms 5cfd94e
keep the names of graph output
titaiwangms 44f6042
address reviews
titaiwangms ab212d6
resolve conflict
titaiwangms 9c2d134
revert
titaiwangms 6a43bfb
fix lint
titaiwangms 3b1b19f
separate import common_subexpression_elimination
titaiwangms 9fd8948
remove cse from optimizer
titaiwangms File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
138 changes: 138 additions & 0 deletions
138
onnxscript/ir/passes/common/common_subexpression_elimination.py
gramalingam marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"CommonSubexpressionEliminationPass", | ||
] | ||
|
||
import logging | ||
from typing import Sequence | ||
|
||
from onnxscript import ir | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass): | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
def call(self, model: ir.Model) -> ir.passes.PassResult: | ||
"""Return the same ir.Model but with CSE applied to the graph.""" | ||
modified = False | ||
graph = model.graph | ||
|
||
modified = _eliminate_common_subexpression(graph, modified) | ||
|
||
return ir.passes.PassResult( | ||
model, | ||
modified=modified, | ||
) | ||
|
||
|
||
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool: | ||
"""Eliminate common subexpression in ONNX graphs.""" | ||
|
||
# node to node identifier, length of outputs, inputs, and attributes | ||
existing_node_info_to_the_node: dict[ | ||
tuple[ | ||
ir.OperatorIdentifier, | ||
int, # len(outputs) | ||
tuple[int, ...], # input ids | ||
tuple[tuple[str, object], ...], # attributes | ||
], | ||
ir.Node, | ||
] = {} | ||
|
||
for node in graph: | ||
# Skip control flow ops like Loop and If. | ||
control_flow_op: bool = False | ||
# Use equality to check if the node is a common subexpression. | ||
attributes = {} | ||
for k, v in node.attributes.items(): | ||
# TODO(exporter team): CSE subgraphs. | ||
# NOTE: control flow ops like Loop and If won't be CSEd | ||
# because attribute: graph won't match. | ||
if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS): | ||
control_flow_op = True | ||
logger.debug("Skipping control flow op %s", node) | ||
# The attribute value could be directly taken from the original | ||
# protobuf, so we need to make a copy of it. | ||
value = v.value | ||
if v.type in ( | ||
ir.AttributeType.INTS, | ||
ir.AttributeType.FLOATS, | ||
ir.AttributeType.STRINGS, | ||
): | ||
# For INT, FLOAT and STRING attributes, we convert them to tuples | ||
# to ensure they are hashable. | ||
value = tuple(value) | ||
attributes[k] = value | ||
|
||
if control_flow_op: | ||
# If the node is a control flow op, we skip it. | ||
continue | ||
|
||
node_info = ( | ||
node.op_identifier(), | ||
len(node.outputs), | ||
tuple(id(input) for input in node.inputs), | ||
tuple(sorted(attributes.items())), | ||
) | ||
|
||
# Check if the node is a common subexpression. | ||
if node_info in existing_node_info_to_the_node: | ||
# If it is, this node has an existing node with the same | ||
# operator, number of outputs, inputs, and attributes. | ||
# We replace the node with the existing node. | ||
modified = True | ||
existing_node = existing_node_info_to_the_node[node_info] | ||
_remove_node_and_replace__values( | ||
graph, | ||
remove_nodes=[node], | ||
remove_values=node.outputs, | ||
new_values=existing_node.outputs, | ||
) | ||
logger.debug("Reusing node %s", existing_node) | ||
else: | ||
# If it is not, add to the mapping. | ||
existing_node_info_to_the_node[node_info] = node | ||
return modified | ||
|
||
|
||
def _remove_node_and_replace__values( | ||
graph: ir.Graph, | ||
/, | ||
remove_nodes: ir.Node, | ||
remove_values: Sequence[ir.Value], | ||
new_values: Sequence[ir.Value], | ||
) -> None: | ||
"""Replaces nodes and values in the graph or function. | ||
|
||
Args: | ||
graph: The graph to replace nodes and values in. | ||
remove_nodes: The nodes to remove. | ||
remove_values: The values to replace. | ||
new_values: The values to replace with. | ||
""" | ||
|
||
for old_value, new_value in zip(remove_values, new_values): | ||
# Propagate relevant info from old value to new value | ||
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider | ||
# merging old and new type/shape info. | ||
new_value.type = old_value.type | ||
new_value.shape = old_value.shape | ||
new_value.const_value = old_value.const_value | ||
new_value.name = old_value.name | ||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Reconnect the users of the deleted values to use the new values | ||
ir.convenience.replace_all_uses_with(remove_values, new_values) | ||
# Update graph/function outputs if the node generates output | ||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||
replacement_mapping = dict(zip(remove_values, new_values)) | ||
for idx, graph_or_function_output in enumerate(graph.outputs): | ||
if graph_or_function_output in replacement_mapping: | ||
graph.outputs[idx] = replacement_mapping[graph_or_function_output] | ||
titaiwangms marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
graph.remove(remove_nodes, safe=True) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.