Always fold the Transpose node in the constant folder #2355

Merged: 16 commits, Jun 5, 2025
178 changes: 105 additions & 73 deletions onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,7 @@
import logging
import math
import typing
from typing import Any, Callable, Iterable, Sequence, Union
from typing import Any, Callable, Collection, Iterable, Sequence, Union

import numpy as np
import onnx
@@ -24,12 +24,7 @@
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024


def is_control_flow_op(node: ir.Node) -> bool:
graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
return any(attr.type in graph_types for attr in node.attributes.values())


non_deterministic_ops = frozenset(
_NON_DETERMINISTIC_OPS = frozenset(
{
"RandomUniform",
"RandomNormal",
@@ -40,21 +35,21 @@
)


def is_non_deterministic_op(node: ir.Node) -> bool:
return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
logger = logging.getLogger(__name__)


def is_onnx_op(node: ir.Node, op_type: str) -> bool:
return node.op_type == op_type and utils.is_onnx_domain(node.domain)
def _is_control_flow_op(node: ir.Node) -> bool:
graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
return any(attr.type in graph_types for attr in node.attributes.values())


def is_constant_op(node: ir.Node) -> bool:
return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
node.domain
)
def _is_non_deterministic_op(node: ir.Node) -> bool:
return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)


logger = logging.getLogger(__name__)
def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
return node.op_type == op_type and utils.is_onnx_domain(node.domain)
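To make the role of these predicate helpers concrete, here is a self-contained sketch that mirrors `_is_control_flow_op` against a minimal stand-in for `ir.Node`. `StubAttr`/`StubNode` are hypothetical test doubles, not onnxscript classes; the predicate body is copied from the diff above:

```python
# Sketch of the control-flow predicate, exercised against stub nodes.
from dataclasses import dataclass, field
from enum import Enum, auto


class AttrType(Enum):  # stand-in for ir.AttributeType
    GRAPH = auto()
    GRAPHS = auto()
    INT = auto()


@dataclass
class StubAttr:
    type: AttrType


@dataclass
class StubNode:
    domain: str
    op_type: str
    attributes: dict = field(default_factory=dict)


def is_control_flow_op(node: StubNode) -> bool:
    graph_types = {AttrType.GRAPH, AttrType.GRAPHS}
    return any(attr.type in graph_types for attr in node.attributes.values())


# An If node carries its branches as GRAPH attributes, so it is control flow
# and the folder will not attempt to evaluate it; a plain Add is foldable.
if_node = StubNode("", "If", {"then_branch": StubAttr(AttrType.GRAPH)})
add_node = StubNode("", "Add")
assert is_control_flow_op(if_node)
assert not is_control_flow_op(add_node)
```

The other two helpers classify by `op_type` and domain alone; none of them evaluate the node.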


# "Standard" evaluators are used to perform constant-folding.
# The API below works only for non-control-flow ops (ops without any graph-attributes).
@@ -168,19 +163,6 @@
def set_sym_value(self, value: ir.Value, sym_value: SymbolicValue) -> None:
self._sym_value_map[value] = sym_value

def push_initializer_inputs(self) -> None:
self._initializer_inputs.append(set())

def pop_initializer_inputs(self) -> None:
self._initializer_inputs.pop()

def add_initializer_input(self, value: ir.Value) -> None:
assert self._initializer_inputs
self._initializer_inputs[-1].add(value)

def is_initializer_input(self, value: ir.Value) -> bool:
return any(value in inputs for inputs in self._initializer_inputs)

def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
if const_value is not None:
@@ -301,6 +283,11 @@
array = const_value.numpy().view(const_value.dtype.numpy())
except FileNotFoundError:
# External data is not available.
logger.warning(

"External data for value '%s' is not available. "
"This may lead to incorrect constant folding.",
val.name,
)
return None
assert isinstance(array, np.ndarray)
return array
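A minimal standalone sketch of the `FileNotFoundError` guard above, with `load_tensor` as a hypothetical stand-in for `const_value.numpy()`. The point is that a tensor whose external data file is missing is treated as non-constant instead of aborting the pass:

```python
import logging

logger = logging.getLogger(__name__)


def safe_numpy(name, load_tensor):
    """Return the tensor data, or None if its external data file is missing."""
    try:
        return load_tensor()
    except FileNotFoundError:
        logger.warning(
            "External data for value '%s' is not available. "
            "This may lead to incorrect constant folding.",
            name,
        )
        return None


def load_missing():
    raise FileNotFoundError("weights.bin")


assert safe_numpy("w_ok", lambda: [1, 2, 3]) == [1, 2, 3]
assert safe_numpy("w_missing", load_missing) is None
```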
@@ -841,28 +828,48 @@


class FoldConstantsPass(ir.passes.InPlacePass):
"""A pass that folds constant expressions in the model.

Attributes:
shape_inference: Whether to perform shape inference.
input_size_limit: Maximum size of input tensors to fold.
output_size_limit: Maximum size of output tensors to fold.
always_fold_ops: Collection of op types that should always be folded.
For ops from the default opset, only the op_type is needed (e.g. "Transpose");
otherwise, specify the domain with ``{domain}::{op_type}``.
"""

def __init__(
self,
*,
shape_inference: bool,
input_size_limit: int,
output_size_limit: int,
always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> None:
self._shape_inference = shape_inference
self._input_size_limit = input_size_limit
self._output_size_limit = output_size_limit
self.opset_imports: dict[str, int] = {}
self.counts: dict[str, int] = {}
self.sizes: dict[str, int] = {}
self.modified: bool = False
self.shape_inference = shape_inference
self.input_size_limit = input_size_limit
self.output_size_limit = output_size_limit
ops = []
for name in always_fold_ops:
domain, op_type = name.split("::", 1) if "::" in name else ("", name)
if domain == "ai.onnx":
domain = ""

ops.append((domain, op_type))
self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops)

self._opset_imports: dict[str, int] = {}
self._counts: dict[str, int] = {}
self._sizes: dict[str, int] = {}
self._modified: bool = False
self._state = OptimizerState()
self._reset()

def _reset(self) -> None:
"""Reset internal states for a new run."""
self.counts = {}
self.sizes = {}
self.modified = False
self._counts = {}
self._sizes = {}
self._modified = False
self._state = OptimizerState()
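For clarity, here is a standalone sketch of the `always_fold_ops` normalization performed in `__init__` above: bare names map to the default ONNX domain `""`, and the `"ai.onnx"` alias is canonicalized to `""` (the `com.microsoft` spec below is purely illustrative):

```python
def parse_always_fold_ops(specs):
    """Normalize op specs to (domain, op_type) tuples."""
    ops = []
    for name in specs:
        domain, op_type = name.split("::", 1) if "::" in name else ("", name)
        if domain == "ai.onnx":  # alias for the default ONNX domain
            domain = ""
        ops.append((domain, op_type))
    return frozenset(ops)


assert parse_always_fold_ops(["Transpose"]) == {("", "Transpose")}
assert parse_always_fold_ops(["ai.onnx::Transpose"]) == {("", "Transpose")}
assert parse_always_fold_ops(["com.microsoft::FusedMatMul"]) == {
    ("com.microsoft", "FusedMatMul")
}
```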

def _do_inference(self, node: ir.Node) -> None:
@@ -896,7 +903,7 @@
# TODO: pass in constant values, ir_version
try:
schema = onnx.defs.get_schema(
node.op_type, self.opset_imports[node.domain], node.domain
node.op_type, self._opset_imports[node.domain], node.domain
)
output_types = onnx.shape_inference.infer_node_outputs(
schema,
@@ -937,7 +944,7 @@
tensor.name = irvalue.name
irvalue.const_value = tensor

if value.nbytes > self._output_size_limit:
if value.nbytes > self.output_size_limit:
# Allow cases like Transpose(weight) to be folded even if the size is large,
# as long as weight has no other uses. This won't increase model size.
removed_input_size = 0
@@ -967,6 +974,7 @@
return node

def process_node(self, node: ir.Node) -> Replacement | None:
"""Process a node and return a Replacement if the node can be replaced."""
for i, value in enumerate(node.inputs):
sym_value = self._state.get_sym_value(value)
if isinstance(sym_value, ir.Value):
@@ -977,16 +985,16 @@
sym_value.name,
)
node.replace_input_with(i, sym_value)
self.modified = True
self._modified = True
# TODO(rama): consider merging type/other info from both values

# Do incremental shape inference
if self._shape_inference and not is_control_flow_op(node):
if self.shape_inference and not _is_control_flow_op(node):
self._do_inference(node)

if node.domain not in self.opset_imports:
if node.domain not in self._opset_imports:
return None
version = self.opset_imports[node.domain]
version = self._opset_imports[node.domain]
op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
for optimizer in op_optimizers:
assert optimizer
@@ -999,31 +1007,58 @@
output = [output]
return Replacement(output, context.nodes)

if is_control_flow_op(node) or is_non_deterministic_op(node):
if _is_control_flow_op(node) or _is_non_deterministic_op(node):
return None

if is_onnx_op(node, "Constant"):
if _is_onnx_op(node, "Constant"):
_process_constant_node(node)
return None

input_values = [_get_numpy_value(x) for x in node.inputs]
if any(x is None for x in input_values):
return None

if any(self._state.is_initializer_input(x) for x in node.inputs): # type: ignore[arg-type]
if any(x.is_graph_input() for x in node.inputs if x is not None):
# Do not fold any graph inputs to preserve graph signature
return None

if any(input.nbytes > self._input_size_limit for input in input_values): # type: ignore[union-attr]
# Ensure all node inputs are constants
if any(x.const_value is None for x in node.inputs if x is not None):
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [input.size for input in input_values] # type: ignore[union-attr]
logger.debug(
"Skipping constant folding for op %s due to large input size: %s",
node.op_type,
input_sizes,
"Skipping constant folding for node %s because it has non-constant inputs",
node,
[x.name for x in node.inputs if x is not None],
)
return None

# Filter out bfloat16 cases?
input_tensors = [x.const_value if x is not None else None for x in node.inputs]

if any(
tensor.nbytes > self.input_size_limit
for tensor in input_tensors
if tensor is not None
):
if (node.domain, node.op_type) in self.always_fold_ops and all(
len(input.consumers()) == 1 for input in node.inputs if input is not None
):
# If the op is in always_fold_ops and all inputs are used only by this node,
# we can still fold it even if the input size exceeds the limit.
logger.debug(
"Folding large constant for node %s because it is in the always_fold_ops list",
node,
)
else:
# Skip folding large tensors
if logger.isEnabledFor(logging.DEBUG):
input_sizes = [

tensor.nbytes for tensor in input_tensors if tensor is not None
]
logger.debug(

"Skipping constant folding for node %s due to large input size: %s",
node,
input_sizes,
)
return None
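The branch above reduces to a small decision rule: a large input blocks folding unless the op is in `always_fold_ops` and every input has exactly one consumer, in which case folding cannot duplicate the tensor. A hypothetical helper, for illustration only:

```python
def should_fold_large(node_key, input_infos, always_fold_ops, input_size_limit):
    """node_key: (domain, op_type); input_infos: (nbytes, n_consumers) per input."""
    if all(nbytes <= input_size_limit for nbytes, _ in input_infos):
        return True  # within the normal size budget
    # Oversized inputs: fold only if the op is exempted and no input is shared.
    return node_key in always_fold_ops and all(
        n_consumers == 1 for _, n_consumers in input_infos
    )


LIMIT = 1024 * 1024
assert should_fold_large(("", "Transpose"), [(10_000_000, 1)],
                         {("", "Transpose")}, LIMIT)
assert not should_fold_large(("", "Add"), [(10_000_000, 1)],
                             {("", "Transpose")}, LIMIT)
```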

input_values = [_get_numpy_value(x) for x in node.inputs]

def convert(av):
if av.type == ir.AttributeType.TENSOR:
return ir.serde.serialize_tensor(av.value)
@@ -1038,7 +1073,7 @@
return None
if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
replacement = self.new_constant(node, outputs)
if is_onnx_op(node, "ConstantOfShape") or replacement is None:
if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
return None
return Replacement(replacement.outputs, [replacement])
else:
@@ -1054,7 +1089,7 @@
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)

self.modified = True
self._modified = True

# TODO: what about new opset_imports?
# TODO: track statistics about replaced nodes and sizes of new constants
@@ -1079,13 +1114,6 @@
self.replace_node(node, replacement, root)

def visit_graph(self, graph: ir.Graph) -> None:
# Track inputs that have a const_value (which is really a default-value, and should not
# be used for constant-folding).
self._state.push_initializer_inputs()
for input in graph.inputs:
if input.const_value is not None:
self._state.add_initializer_input(input)

for node in graph:
self.visit_node(node, graph)

@@ -1103,22 +1131,20 @@
# Rename sym_value to match the output name
sym_value.name = output.name
graph.outputs[i] = sym_value
self.modified = True

self._state.pop_initializer_inputs()
self._modified = True

def visit_function(self, function: ir.Function) -> None:
for node in function:
self.visit_node(node, function)

def call(self, model: ir.Model) -> ir.passes.PassResult:
def call(self, model: ir.Model) -> FoldConstantsResult:
self._reset()
self.opset_imports = model.opset_imports
self._opset_imports = model.opset_imports
self.visit_graph(model.graph)
for function in model.functions.values():
# TODO(rama): Should we specialize functions?
self.visit_function(function)
return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)


def _sym_value_can_replace_graph_output(
@@ -1155,6 +1181,7 @@
onnx_shape_inference: bool = False,
input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
always_fold_ops: Collection[str] = frozenset(["Transpose"]),
) -> FoldConstantsResult:
"""
Applies constant folding optimization to the model.
@@ -1169,6 +1196,10 @@
output_size_limit: The maximum size (in bytes) of output tensors
that can be stored after constant folding. Defaults to
`DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
always_fold_ops: A collection of op types that should always be folded,
regardless of their input or output sizes. For ops from the default opset,
only the op_type is needed (e.g. "Transpose"); otherwise, specify the domain
with ``{domain}::{op_type}``.

Returns:
An instance of `FoldConstantsResult`.
@@ -1178,5 +1209,6 @@
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
always_fold_ops=always_fold_ops,
)
return folder_pass(model) # type: ignore[return-value]
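For reference, a hedged usage sketch of the new parameter (not part of this PR). It assumes onnxscript's `ir.load`/`ir.save` helpers, the module path shown in this diff, and that `FoldConstantsResult` carries the usual `PassResult.modified` flag; `model.onnx` and the `com.microsoft` op are placeholders:

```python
from onnxscript import ir
from onnxscript.optimizer._constant_folding import fold_constants

model = ir.load("model.onnx")
result = fold_constants(
    model,
    onnx_shape_inference=False,
    # Fold these ops even when an input exceeds input_size_limit, provided
    # each input has a single consumer so folding cannot grow the model.
    always_fold_ops={"Transpose", "com.microsoft::FusedMatMul"},
)
if result.modified:
    ir.save(model, "model_folded.onnx")
```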