Remove legacy optimizer #2180


Merged · 31 commits · Apr 21, 2025
Commits (31)
aebc6b0
Remove legacy optimizer
justinchuby Apr 10, 2025
9b8e955
test
justinchuby Apr 10, 2025
1bb22a0
Fix test-cases
gramalingam Apr 11, 2025
e8fa66a
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 11, 2025
cafdcd8
wip tests
justinchuby Apr 11, 2025
5fbcd1a
test
justinchuby Apr 11, 2025
11d183b
test
justinchuby Apr 11, 2025
23ae61a
typing
justinchuby Apr 11, 2025
67a107b
Implement final identity folding
justinchuby Apr 11, 2025
e0bec82
Update onnxscript/optimizer/_constant_folding.py
justinchuby Apr 11, 2025
23d6718
Update onnxscript/optimizer/_optimizer.py
justinchuby Apr 11, 2025
52b0656
test
justinchuby Apr 12, 2025
a3c657f
_inliner
justinchuby Apr 12, 2025
711ff70
Fix tests
justinchuby Apr 12, 2025
2a484e6
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 12, 2025
1540661
Handle edge case
justinchuby Apr 12, 2025
2a76657
comment
justinchuby Apr 12, 2025
95e216e
test
justinchuby Apr 12, 2025
5b7f9c5
Fix subgraph
justinchuby Apr 12, 2025
915c1b0
fix
justinchuby Apr 12, 2025
d28e62b
d
justinchuby Apr 12, 2025
3358372
refactor
justinchuby Apr 12, 2025
05c4941
fix
justinchuby Apr 12, 2025
c4814d9
docs
justinchuby Apr 12, 2025
0665303
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 14, 2025
2753b46
Fix test
justinchuby Apr 15, 2025
60f0c1b
Fix tests
justinchuby Apr 15, 2025
ae8438a
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 19, 2025
aadd644
lint
justinchuby Apr 19, 2025
5291e85
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 21, 2025
c1e8a40
skip tests for now
justinchuby Apr 21, 2025
92 changes: 77 additions & 15 deletions onnxscript/optimizer/__init__.py
@@ -2,37 +2,90 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import TypeVar

__all__ = [
"fold_constants",
"fold_constants_ir",
"remove_unused_nodes",
"optimize",
"optimize_ir",
"basic_constant_propagation",
"fold_constants_ir",
"fold_constants",
"inline",
"optimize_ir",
"optimize",
"remove_unused_nodes",
]

import onnx

import onnxscript.ir.passes.common.inliner
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer._constant_folding as constant_folding
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._optimizer import optimize_ir

basic_constant_propagation = constant_folding.basic_constant_propagation
fold_constants_ir = constant_folding.fold_constants
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)


def optimize(
model: _ModelProtoOrIr,
num_iterations: int = 2,
*,
onnx_shape_inference: bool = True,
stop_if_no_change: bool = True,
input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
inline: bool = True,
) -> _ModelProtoOrIr:
"""Optimizes a model.

def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
Args:
model: The model to be optimized.
num_iterations: Number of times the optimization loop is repeated.
onnx_shape_inference: Applies node-level shape inference as part of optimization.
input_size_limit: Will not apply constant folding to ops with any input of size
greater than this. Does not apply to special ops like Shape() and Size().
output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
of the output tensor is greater than this.
stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
inline: If True, inlines all functions in the model.

Returns:
The optimized model. If the input was a ModelProto, the output will also be a
ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
"""
if isinstance(model, ir.Model):
# In that case, this is done inplace.
optimize_ir(model, *args, **kwargs)
# In this case, optimize is done inplace.
# TODO(justinchuby): Maybe make functional
optimize_ir(
model,
num_iterations=num_iterations,
onnx_shape_inference=onnx_shape_inference,
stop_if_no_change=stop_if_no_change,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
inline=inline,
)
return model
else:
return legacy_optimizer.optimize(model, *args, **kwargs)
assert isinstance(model, onnx.ModelProto)
model_ir = ir.serde.deserialize_model(model)
optimize_ir(
model_ir,
num_iterations=num_iterations,
onnx_shape_inference=onnx_shape_inference,
stop_if_no_change=stop_if_no_change,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
inline=inline,
)
# Move the model back to the proto
new_proto = ir.serde.serialize_model(model_ir)
return new_proto
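
A minimal usage sketch of the reworked optimize() entry point shown above. The file paths and keyword values below are illustrative, not defaults introduced by this PR; the keyword names match the new signature.

```python
# Sketch: optimizing a ModelProto with the new keyword-only options.
# "model.onnx" and "model_optimized.onnx" are placeholder paths.
import onnx

from onnxscript import optimizer

model_proto = onnx.load("model.onnx")

# ModelProto in -> ModelProto out; an ir.Model input would instead be
# optimized in place and returned.
optimized_proto = optimizer.optimize(
    model_proto,
    num_iterations=2,
    onnx_shape_inference=True,
    stop_if_no_change=True,
    inline=True,
)

onnx.save(optimized_proto, "model_optimized.onnx")
```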


def inline(model: ir.Model) -> None:
@@ -43,11 +96,20 @@ def inline(model: ir.Model) -> None:

def fold_constants(
model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult | bool:
) -> constant_folding.FoldConstantsResult:
"""Fold constants in a model in place."""
if isinstance(model, ir.Model):
return constant_folding.fold_constants(model, *args, **kwargs)
else:
return legacy_constant_folding.fold_constants(model, *args, **kwargs)
assert isinstance(model, onnx.ModelProto)
model_proto = model
model = ir.serde.deserialize_model(model_proto)
result = constant_folding.fold_constants(model, *args, **kwargs)
# Move the model back to the proto
new_proto = ir.serde.serialize_model(model)
model_proto.Clear()
model_proto.CopyFrom(new_proto)
return result
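
To make the ModelProto path above concrete, here is a small sketch. It assumes FoldConstantsResult exposes the `modified` flag from ir.passes.PassResult; note that the proto is updated in place via Clear()/CopyFrom() even though a result object is returned.

```python
# Sketch: folding constants in a ModelProto in place.
# "model.onnx" is a placeholder path; `result.modified` assumes the
# PassResult `modified` flag is available on FoldConstantsResult.
import onnx

from onnxscript import optimizer

model_proto = onnx.load("model.onnx")
result = optimizer.fold_constants(model_proto)

# model_proto now holds the folded graph; the result reports whether anything changed.
print("modified:", result.modified)
```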


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
44 changes: 39 additions & 5 deletions onnxscript/optimizer/_constant_folding.py
@@ -919,7 +919,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
e,
)

def new_constant(self, node: ir.Node, value):
def new_constant(self, node: ir.Node, value) -> ir.Node | None:
irvalue = node.outputs[0]
if not isinstance(value, np.ndarray):
# ONNX does not have a way to represent non-tensor constants, eg. a sequence.
@@ -965,7 +965,7 @@ def new_constant(self, node: ir.Node, value):
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node

def process_node(self, node: ir.Node):
def process_node(self, node: ir.Node) -> Replacement | None:
for i, value in enumerate(node.inputs):
sym_value = self._state.get_sym_value(value)
if isinstance(sym_value, ir.Value):
@@ -1046,7 +1046,7 @@ def convert(av):
)
return None

def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

ir.convenience.replace_nodes_and_values(
@@ -1066,13 +1066,13 @@ def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
for graph in attr.as_graphs():
self.visit_graph(graph)

def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
replacement = self.process_node(node)
if replacement is None:
# No change. Process attributes.
for attr in node.attributes.values():
self.visit_attribute(attr)
return None
return
else:
self.replace_node(node, replacement, root)

@@ -1087,6 +1087,22 @@ def visit_graph(self, graph: ir.Graph) -> None:
for node in graph:
self.visit_node(node, graph)

# Replace graph outputs if they can be folded. These are typically outputs of
# Identity nodes.
for i, output in enumerate(graph.outputs):
if output is None:
continue
sym_value = self._state.get_sym_value(output)
if not isinstance(sym_value, ir.Value):
# An output must be a Value
continue
if not _sym_value_can_replace_graph_output(graph, sym_value, output):
continue
# Rename sym_value to match the output name
sym_value.name = output.name
graph.outputs[i] = sym_value
self.modified = True

self._state.pop_initializer_inputs()

def visit_function(self, function: ir.Function) -> None:
@@ -1103,6 +1119,24 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)


def _sym_value_can_replace_graph_output(
graph: ir.Graph, sym_value: ir.Value, output: ir.Value
) -> bool:
if (producer := sym_value.producer()) is None:
# If the sym_value has no producer, it is some graph's input
# ONNX does not allow a graph input to be a graph output
return False
if producer.graph is not graph:
# The sym_value must be produced by a node in the graph to be an output of this graph
return False
if sym_value.is_graph_output():
# If the sym_value is already an output of a graph, we cannot rename it
# to this output name. Otherwise the graph output represented by sym_value
# will lose its name.
return False
return True
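
A small end-to-end illustration of the new graph-output folding (hypothetical node and tensor names): a graph whose output is produced by a trailing Identity should have that output rewired to the Identity's input, which is then renamed to preserve the original output name so the Identity can be removed.

```python
# Sketch: a graph ending in Identity; after optimization the Identity
# should be folded away and "out" produced directly by the Add.
import onnx
from onnx import TensorProto, helper

from onnxscript import optimizer

add = helper.make_node("Add", ["x", "y"], ["sum"], name="add")
ident = helper.make_node("Identity", ["sum"], ["out"], name="ident")
graph = helper.make_graph(
    [add, ident],
    "identity_tail",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("y", TensorProto.FLOAT, [2]),
    ],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

optimized = optimizer.optimize(model)
print([n.op_type for n in optimized.graph.node])  # expect no Identity
```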


@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
symbolic_value_map: dict[ir.Value, SymbolicValue]