[rewriter] Transpose initializer rule #2255


Closed · wants to merge 23 commits
2 changes: 2 additions & 0 deletions onnxscript/rewriter/__init__.py
@@ -22,6 +22,7 @@
    llama_rule_sets,
    no_op,
    pattern,
    transpose_initializer,
)

_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)

@@ -32,6 +33,7 @@
    *cast_constant_of_shape.rules.rules,
    *collapse_slices.rules.rules,
    *llama_rule_sets.llama_p0_rule_set().rules,
    transpose_initializer.rule,
)


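For context (not part of the diff): once registered in the default rule list above, the rule is applied through the module's rewrite entry point. A minimal sketch, assuming the public onnxscript.rewriter.rewrite API accepts an onnx.ModelProto and returns the rewritten model:

    import onnx
    import onnxscript.rewriter

    model = onnx.load("model.onnx")  # hypothetical input path
    # Applies the default pattern rewrite rules, which now include
    # transpose_initializer.rule.
    model = onnxscript.rewriter.rewrite(model)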
48 changes: 46 additions & 2 deletions onnxscript/rewriter/_rewrite_rule.py
@@ -18,7 +18,7 @@
import onnxscript.rewriter._matcher as _matcher
import onnxscript.rewriter._pattern_ir as _pattern_ir
from onnxscript import ir
-from onnxscript.ir import _convenience, _tape
+from onnxscript.ir import _tape

T = TypeVar("T")

@@ -81,6 +81,50 @@ def _update_opset_imports(
)


def _replace_nodes_and_values(
    graph_or_function: ir.Graph | ir.Function,
    /,
    insertion_point: ir.Node,
    old_nodes: Sequence[ir.Node],
    new_nodes: Sequence[ir.Node],
    old_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    Args:
        graph_or_function: The graph or function to replace nodes and values in.
        insertion_point: The node to insert the new nodes after.
        old_nodes: The nodes to replace.
        new_nodes: The nodes to replace with.
        old_values: The values to replace.
        new_values: The values to replace with.
    """
    for old_value, new_value in zip(old_values, new_values):
        # Propagate relevant info from old value to new value
        if new_value.type is None:
            new_value.type = old_value.type
        if new_value.shape is None:
            new_value.shape = old_value.shape
        if new_value.const_value is None:
            new_value.const_value = old_value.const_value
        if new_value.name is None:
            new_value.name = old_value.name

    # Reconnect the users of the deleted values to use the new values
    ir.convenience.replace_all_uses_with(old_values, new_values)
    # Update graph/function outputs if the node generates output
    replacement_mapping = dict(zip(old_values, new_values))
    for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
        if graph_or_function_output in replacement_mapping:
            graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

    # Insert the new nodes after the insertion point, then remove the old nodes
    graph_or_function.insert_after(insertion_point, new_nodes)
    graph_or_function.remove(old_nodes, safe=True)

class RewriteRule:
def __init__(
self,
@@ -525,7 +569,7 @@ def _apply_to_graph_or_function(
)
f = ir.Function(domain, name, overload, graph=graph, attributes=())
model.functions[f.identifier()] = f
-_convenience.replace_nodes_and_values(
+_replace_nodes_and_values(
graph_or_function,
node,
delta.match.nodes if rule.remove_nodes else [],
63 changes: 63 additions & 0 deletions onnxscript/rewriter/transpose_initializer.py
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rules to collapse Transpose nodes into initializers."""

from __future__ import annotations

import logging

import numpy as np

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp

logger = logging.getLogger(__name__)


class TransposeInitializer(orp.RewriteRuleClassBase):
    """Folds Transpose nodes into initializers."""

    def __init__(self):
        super().__init__("TransposeInitializer", remove_nodes=True)

    def pattern(self, op, initializer):
        return op.Transpose(initializer, _allow_other_attributes=True)

    def rewrite(self, op, initializer: ir.Value) -> ir.Value:
        original_transpose = initializer.consumers()[0]
        perm_attr = original_transpose.attributes.get("perm")

        if perm_attr is not None:
            perm = perm_attr.as_ints()
        else:
            perm = None

[Codecov warning: added line #L34 was not covered by tests]
Member:
Can we eliminate that case in def check?

Collaborator (Author):
When perm is None the transpose can still be evaluated. So I don't think it can be eliminated in check()?
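A quick standalone illustration of that point (not part of the diff): NumPy's transpose with axes=None reverses the axis order, which is exactly the ONNX Transpose default when perm is absent, so the rewrite can fold either form:

    import numpy as np

    a = np.arange(24).reshape(2, 3, 4)
    # axes=None reverses the axes, like ONNX Transpose without perm
    assert np.transpose(a).shape == (4, 3, 2)
    # explicit permutation, like ONNX Transpose with perm=[1, 0, 2]
    assert np.transpose(a, axes=(1, 0, 2)).shape == (3, 2, 4)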


        array = ir_utils.get_numpy_value(initializer)
        if array is None:
            # Do nothing
            logger.debug("Failed to obtain the initializer value. Do nothing")

[Codecov warning: added line #L39 was not covered by tests]
            # perm=None is filtered out when the attribute is constructed so we are ok
            return op.Transpose(initializer, perm=perm_attr)

[Codecov warning: added line #L41 was not covered by tests]

        # np.transpose does not create a copy. So we don't need to use LazyTensors.
        transposed = np.transpose(array, axes=perm)
        new_name = f"{initializer.name}_transposed"
        return op.initializer(ir.tensor(transposed, name=new_name))

    def check(self, context, initializer: ir.Value) -> orp.MatchResult:
        del context  # Unused
        check_result = orp.MatchResult()
        if not initializer.is_initializer():
            return check_result.fail("Value is not an initializer")
        if initializer.is_graph_input():
            return check_result.fail("Value is a graph input")

[Codecov warning: added line #L54 was not covered by tests]
        if initializer.const_value is None:
            return check_result.fail("Value.const_value is None")

[Codecov warning: added line #L56 was not covered by tests]
        if len(initializer.uses()) != 1:
            return check_result.fail("Initializer is used by more than one node")
        # TODO(justinchuby): Avoid matching when it is a graph input
        return check_result


rule = TransposeInitializer.rule()
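To see the rule end to end, here is a hypothetical sketch (not from the PR) that builds a model whose Transpose consumes an initializer and runs it through the rewriter; it assumes the public onnxscript.rewriter.rewrite entry point:

    import numpy as np
    import onnx
    from onnx import helper, numpy_helper

    import onnxscript.rewriter

    # Initializer W feeds a Transpose; Identity keeps the graph output distinct.
    weight = numpy_helper.from_array(
        np.arange(6, dtype=np.float32).reshape(2, 3), name="W"
    )
    nodes = [
        helper.make_node("Transpose", ["W"], ["WT"], perm=[1, 0]),
        helper.make_node("Identity", ["WT"], ["Y"]),
    ]
    graph = helper.make_graph(
        nodes,
        "transpose_initializer_example",
        inputs=[],
        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2])],
        initializer=[weight],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

    rewritten = onnxscript.rewriter.rewrite(model)
    # Expected: the Transpose is folded away and a pre-transposed initializer
    # (e.g. "W_transposed") feeds the Identity directly.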