Always fold the Transpose node in the constant folder #2355

Merged
merged 16 commits into from Jun 5, 2025

Changes from 8 commits
136 changes: 79 additions & 57 deletions onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,7 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Iterable, Sequence, Union
+from typing import Any, Callable, Collection, Iterable, Sequence, Union

 import numpy as np
 import onnx
@@ -24,12 +24,7 @@
 DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 1024 * 1024


-def is_control_flow_op(node: ir.Node) -> bool:
-    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
-    return any(attr.type in graph_types for attr in node.attributes.values())
-
-
-non_deterministic_ops = frozenset(
+_NON_DETERMINISTIC_OPS = frozenset(
     {
         "RandomUniform",
         "RandomNormal",
@@ -40,21 +35,21 @@ def is_control_flow_op(node: ir.Node) -> bool:
 )


-def is_non_deterministic_op(node: ir.Node) -> bool:
-    return node.op_type in non_deterministic_ops and utils.is_onnx_domain(node.domain)
+logger = logging.getLogger(__name__)


-def is_onnx_op(node: ir.Node, op_type: str) -> bool:
-    return node.op_type == op_type and utils.is_onnx_domain(node.domain)
+def _is_control_flow_op(node: ir.Node) -> bool:
+    graph_types = {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+    return any(attr.type in graph_types for attr in node.attributes.values())


-def is_constant_op(node: ir.Node) -> bool:
-    return node.op_type in {"Constant", "ConstantOfShape"} and utils.is_onnx_domain(
-        node.domain
-    )
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    return node.op_type in _NON_DETERMINISTIC_OPS and utils.is_onnx_domain(node.domain)


-logger = logging.getLogger(__name__)
+def _is_onnx_op(node: ir.Node, op_type: str) -> bool:
+    return node.op_type == op_type and utils.is_onnx_domain(node.domain)


 # "Standard" evaluators are used to perform constant-folding.
 # The API below works only for non-control-flow ops (ops without any graph-attributes).
@@ -301,6 +296,11 @@ def _get_numpy_value(
         array = const_value.numpy().view(const_value.dtype.numpy())
     except FileNotFoundError:
         # External data is not available.
+        logger.warning(
+            "External data for value '%s' is not available. "
+            "This may lead to incorrect constant folding.",
+            val.name,
+        )
         return None
     assert isinstance(array, np.ndarray)
     return array
@@ -841,28 +841,40 @@ def merge_dims(dim1, dim2):


 class FoldConstantsPass(ir.passes.InPlacePass):
+    """A pass that folds constant expressions in the model.
+
+    Attributes:
+        shape_inference: Whether to perform shape inference.
+        input_size_limit: Maximum size of input tensors to fold.
+        output_size_limit: Maximum size of output tensors to fold.
+        always_fold_ops: Collection of op types that should always be folded.
+    """

     def __init__(
         self,
         *,
         shape_inference: bool,
         input_size_limit: int,
         output_size_limit: int,
+        always_fold_ops: Collection[str] = frozenset(["Transpose"]),
     ) -> None:
-        self._shape_inference = shape_inference
-        self._input_size_limit = input_size_limit
-        self._output_size_limit = output_size_limit
-        self.opset_imports: dict[str, int] = {}
-        self.counts: dict[str, int] = {}
-        self.sizes: dict[str, int] = {}
-        self.modified: bool = False
+        self.shape_inference = shape_inference
+        self.input_size_limit = input_size_limit
+        self.output_size_limit = output_size_limit
+        self.always_fold_ops: frozenset[str] = frozenset(always_fold_ops)
+
+        self._opset_imports: dict[str, int] = {}
+        self._counts: dict[str, int] = {}
+        self._sizes: dict[str, int] = {}
+        self._modified: bool = False
         self._state = OptimizerState()
+        self._reset()

     def _reset(self) -> None:
         """Reset internal states for a new run."""
-        self.counts = {}
-        self.sizes = {}
-        self.modified = False
+        self._counts = {}
+        self._sizes = {}
+        self._modified = False
         self._state = OptimizerState()

     def _do_inference(self, node: ir.Node) -> None:
@@ -896,7 +908,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
         # TODO: pass in constant values, ir_version
         try:
             schema = onnx.defs.get_schema(
-                node.op_type, self.opset_imports[node.domain], node.domain
+                node.op_type, self._opset_imports[node.domain], node.domain
             )
             output_types = onnx.shape_inference.infer_node_outputs(
                 schema,
@@ -937,7 +949,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         tensor.name = irvalue.name
         irvalue.const_value = tensor

-        if value.nbytes > self._output_size_limit:
+        if value.nbytes > self.output_size_limit:
             # Handle examples like Transpose(weight) to be folded even if the size is large,
             # as long as weight has no other uses. This won't increase model size.
             removed_input_size = 0
@@ -967,6 +979,7 @@ def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         return node

     def process_node(self, node: ir.Node) -> Replacement | None:
+        """Process a node and return a Replacement if the node can be replaced."""
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
             if isinstance(sym_value, ir.Value):
@@ -977,16 +990,16 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                     sym_value.name,
                 )
                 node.replace_input_with(i, sym_value)
-                self.modified = True
+                self._modified = True
                 # TODO(rama): consider merging type/other info from both values

         # Do incremental shape inference
-        if self._shape_inference and not is_control_flow_op(node):
+        if self.shape_inference and not _is_control_flow_op(node):
             self._do_inference(node)

-        if node.domain not in self.opset_imports:
+        if node.domain not in self._opset_imports:
             return None
-        version = self.opset_imports[node.domain]
+        version = self._opset_imports[node.domain]
         op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
         for optimizer in op_optimizers:
             assert optimizer
@@ -999,31 +1012,45 @@ def process_node(self, node: ir.Node) -> Replacement | None:
                 output = [output]
             return Replacement(output, context.nodes)

-        if is_control_flow_op(node) or is_non_deterministic_op(node):
+        if _is_control_flow_op(node) or _is_non_deterministic_op(node):
             return None

-        if is_onnx_op(node, "Constant"):
+        if _is_onnx_op(node, "Constant"):
             _process_constant_node(node)
             return None

-        input_values = [_get_numpy_value(x) for x in node.inputs]
-        if any(x is None for x in input_values):
+        if any(x.is_graph_input() for x in node.inputs if x is not None):
+            # Do not fold any graph inputs to preserve graph signature
             return None

-        if any(self._state.is_initializer_input(x) for x in node.inputs):  # type: ignore[arg-type]
+        # Ensure all node inputs are constants
+        if any(x.const_value is None for x in node.inputs if x is not None):
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "Skipping constant folding for node %s because it has None constant inputs",
+                    node,
+                    [x.name for x in node.inputs if x is not None],
+                )
             return None

-        if any(input.nbytes > self._input_size_limit for input in input_values):  # type: ignore[union-attr]
+        input_tensors = [x.const_value if x is not None else None for x in node.inputs]
+
+        if node.op_type not in self.always_fold_ops and any(
+            tensor.nbytes > self.input_size_limit
+            for tensor in input_tensors
+            if tensor is not None
+        ):
             if logger.isEnabledFor(logging.DEBUG):
-                input_sizes = [input.size for input in input_values]  # type: ignore[union-attr]
+                input_sizes = [tensor.nbytes for tensor in input_tensors if tensor is not None]
                 logger.debug(
-                    "Skipping constant folding for op %s due to large input size: %s",
-                    node.op_type,
+                    "Skipping constant folding for node %s due to large input size: %s",
+                    node,
                     input_sizes,
                 )
             return None

         # Filter out bfloat16 cases?
+        input_values = [_get_numpy_value(x) for x in node.inputs]
+
         def convert(av):
             if av.type == ir.AttributeType.TENSOR:
                 return ir.serde.serialize_tensor(av.value)
@@ -1038,7 +1065,7 @@ def convert(av):
             return None
         if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)):
             replacement = self.new_constant(node, outputs)
-            if is_onnx_op(node, "ConstantOfShape") or replacement is None:
+            if _is_onnx_op(node, "ConstantOfShape") or replacement is None:
                 return None
             return Replacement(replacement.outputs, [replacement])
         else:
@@ -1054,7 +1081,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function)
             root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
         )

-        self.modified = True
+        self._modified = True

         # TODO: what about new opset_imports?
         # TODO: track statistics about replaced nodes and sizes of new constants
@@ -1079,13 +1106,6 @@ def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
             self.replace_node(node, replacement, root)

     def visit_graph(self, graph: ir.Graph) -> None:
-        # Track inputs that have a const_value (which is really a default-value, and should not
-        # be used for constant-folding).
-        self._state.push_initializer_inputs()
-        for input in graph.inputs:
-            if input.const_value is not None:
-                self._state.add_initializer_input(input)
-
         for node in graph:
             self.visit_node(node, graph)

@@ -1103,22 +1123,20 @@ def visit_graph(self, graph: ir.Graph) -> None:
                 # Rename sym_value to match the output name
                 sym_value.name = output.name
                 graph.outputs[i] = sym_value
-                self.modified = True
-
-        self._state.pop_initializer_inputs()
+                self._modified = True

     def visit_function(self, function: ir.Function) -> None:
         for node in function:
             self.visit_node(node, function)

-    def call(self, model: ir.Model) -> ir.passes.PassResult:
+    def call(self, model: ir.Model) -> FoldConstantsResult:
         self._reset()
-        self.opset_imports = model.opset_imports
+        self._opset_imports = model.opset_imports
         self.visit_graph(model.graph)
         for function in model.functions.values():
             # TODO(rama): Should we specialize functions?
             self.visit_function(function)
-        return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
+        return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map)


 def _sym_value_can_replace_graph_output(
@@ -1155,6 +1173,7 @@ def fold_constants(
     onnx_shape_inference: bool = False,
     input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
     output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    always_fold_ops: Collection[str] = frozenset(["Transpose"]),
 ) -> FoldConstantsResult:
     """
     Applies constant folding optimization to the model.
@@ -1169,6 +1188,8 @@
         output_size_limit: The maximum size (in bytes) of output tensors
             that can be stored after constant folding. Defaults to
             `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`.
+        always_fold_ops: A collection of op types that should always be folded,
+            regardless of their input or output sizes.

     Returns:
         An instance of `FoldConstantsResult`.
@@ -1178,5 +1199,6 @@
         shape_inference=onnx_shape_inference,
         input_size_limit=input_size_limit,
         output_size_limit=output_size_limit,
+        always_fold_ops=always_fold_ops,
     )
     return folder_pass(model)  # type: ignore[return-value]
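
For context, here is a minimal usage sketch of the new `always_fold_ops` parameter. This is a sketch only: it assumes the in-repo import path shown in this diff and adapts the placeholder-initializer pattern from the tests below.

```python
import numpy as np

from onnxscript import ir
from onnxscript.optimizer._constant_folding import fold_constants

# A model whose large weight is used only once, by a Transpose (the [1, 1]
# initializer is a placeholder that is resized below, as in the tests in this PR).
model = ir.from_onnx_text("""
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[M, 256] x) => (float[M, 512] z)
        <float[1, 1] w = {1.0}>
    {
        wt = Transpose (w)
        z = MatMul (x, wt)
    }
""")
w = model.graph.initializers["w"]
w.shape = ir.Shape([512, 256])
w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))

# Even with input_size_limit=1 the Transpose is folded, because "Transpose" is
# in the default always_fold_ops and folding a single-use weight does not
# increase model size. The pass mutates the model in place.
fold_constants(model, input_size_limit=1)
print([node.op_type for node in model.graph])  # expected: ['Constant', 'MatMul']
```

Passing `always_fold_ops=frozenset()` opts out of the special case and restores the purely size-gated behavior.
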
45 changes: 33 additions & 12 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -536,31 +536,52 @@ def test_gather_symdim(self):
         optimized = self._fold(model)
         self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

-    def test_large_transpose(self):
+    def test_input_size_limit(self):
         model_text = """
             <ir_version: 7, opset_import: [ "" : 17]>
-            agraph (float[M, 256] x) => (float[M, 512] z)
-                <float[1, 1] w = {1.0}> # placeholder for large initializer of shape [512, 256]
+            agraph (float[M, 256] x) => (float[M, 256] z)
+                <float[1, 1] w = {1.0}> # placeholder for large initializer of shape [256, 256]
             {
-                wt = Transpose (w)
-                z = MatMul (x, wt)
+                w_squared = Mul (w, w)
+                z = Add (x, w_squared)
             }
         """
         model = ir.from_onnx_text(model_text)
         w = model.graph.initializers["w"]
-        w.shape = ir.Shape([512, 256])
-        w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
+        w.shape = ir.Shape([256, 256])
+        w.const_value = ir.tensor(np.random.random((256, 256)).astype(np.float32))

-        # Input size limit will prevent folding of Transpose op
-        optimized = self._fold(model, input_size_limit=3 * 512 * 256)
+        # Input size limit will prevent folding of Mul op
+        optimized = self._fold(model, input_size_limit=3 * 256 * 256)
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Transpose", "MatMul"])
+        self.assertEqual(ops, ["Mul", "Add"])

-        # Input size limit will allow folding of Transpose op
+        # Input size limit will allow folding of Mul op
         # Since there is no increase in model-size, output-size is not a concern.
-        optimized = self._fold(model, input_size_limit=4 * 512 * 256)
+        optimized = self._fold(
+            model, input_size_limit=4 * 256 * 256, output_size_limit=4 * 256 * 256
+        )
         ops = [node.op_type for node in optimized.graph]
-        self.assertEqual(ops, ["Constant", "MatMul"])
+        self.assertEqual(ops, ["Constant", "Add"])
+
+    def test_transpose_is_always_folded(self):
+        model_text = """
+            <ir_version: 7, opset_import: [ "" : 17]>
+            agraph (float[M, 256] x) => (float[M, 512] z)
+                <float[1, 1] w = {1.0}> # placeholder for large initializer of shape [512, 256]
+            {
+                z = Transpose (w)
+            }
+        """
+        model = ir.from_onnx_text(model_text)
+        w = model.graph.initializers["w"]
+        w.shape = ir.Shape([512, 256])
+        w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32))
+
+        # Input size limit will not prevent folding of Transpose op
+        optimized = self._fold(model, input_size_limit=1)
+        ops = [node.op_type for node in optimized.graph]
+        self.assertEqual(ops, ["Constant"])

     def test_multi_graph_identity_output_preserves_output_name(self):
         model = """
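
A note on the size arithmetic in `test_input_size_limit` above: the folder compares `tensor.nbytes` against `input_size_limit`, and a float32 `[256, 256]` initializer occupies 4 bytes per element. A standalone sanity check of the two bounds used in the test (plain NumPy; not part of this diff):

```python
import numpy as np

w = np.random.random((256, 256)).astype(np.float32)

assert w.nbytes == 4 * 256 * 256      # 262144 bytes: 4 bytes per float32 element
assert w.nbytes > 3 * 256 * 256       # limit of 3*256*256 -> Mul is skipped
assert not w.nbytes > 4 * 256 * 256   # limit of 4*256*256 -> Mul is folded
```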