Skip to content

Remove legacy optimizer #2180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
aebc6b0
Remove legacy optimizer
justinchuby Apr 10, 2025
9b8e955
test
justinchuby Apr 10, 2025
1bb22a0
Fix test-cases
gramalingam Apr 11, 2025
e8fa66a
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 11, 2025
cafdcd8
wip tests
justinchuby Apr 11, 2025
5fbcd1a
test
justinchuby Apr 11, 2025
11d183b
test
justinchuby Apr 11, 2025
23ae61a
typing
justinchuby Apr 11, 2025
67a107b
Implement final identity folding
justinchuby Apr 11, 2025
e0bec82
Update onnxscript/optimizer/_constant_folding.py
justinchuby Apr 11, 2025
23d6718
Update onnxscript/optimizer/_optimizer.py
justinchuby Apr 11, 2025
52b0656
test
justinchuby Apr 12, 2025
a3c657f
_inliner
justinchuby Apr 12, 2025
711ff70
Fix tests
justinchuby Apr 12, 2025
2a484e6
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 12, 2025
1540661
Handle edge case
justinchuby Apr 12, 2025
2a76657
comment
justinchuby Apr 12, 2025
95e216e
test
justinchuby Apr 12, 2025
5b7f9c5
Fix subgraph
justinchuby Apr 12, 2025
915c1b0
fix
justinchuby Apr 12, 2025
d28e62b
d
justinchuby Apr 12, 2025
3358372
refactor
justinchuby Apr 12, 2025
05c4941
fix
justinchuby Apr 12, 2025
c4814d9
docs
justinchuby Apr 12, 2025
0665303
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 14, 2025
2753b46
Fix test
justinchuby Apr 15, 2025
60f0c1b
Fix tests
justinchuby Apr 15, 2025
ae8438a
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 19, 2025
aadd644
lint
justinchuby Apr 19, 2025
5291e85
Merge branch 'main' into justinchu/remove-legacy
justinchuby Apr 21, 2025
c1e8a40
skip tests for now
justinchuby Apr 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,109 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import TypeVar

# Public API of the optimizer package, sorted case-insensitively.
__all__ = [
    "basic_constant_propagation",
    "fold_constants_ir",
    "fold_constants",
    "inline",
    "optimize_ir",
    "optimize",
    "remove_unused_nodes",
]

import onnx

import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer._constant_folding as constant_folding
from onnxscript import ir
from onnxscript.optimizer._constant_folding import (
    basic_constant_propagation,
)
from onnxscript.optimizer._constant_folding import (
    fold_constants as fold_constants_ir,
)
from onnxscript.optimizer._inliner import inline
from onnxscript.optimizer._optimizer import optimize_ir

# Constrained TypeVar: functions taking _ModelProtoOrIr return the same
# concrete type they were given (ModelProto in, ModelProto out; ir.Model in,
# ir.Model out).
_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)


def optimize(
    model: _ModelProtoOrIr,
    num_iterations: int = 2,
    *,
    onnx_shape_inference: bool = True,
    stop_if_no_change: bool = True,
    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
    inline: bool = True,
) -> _ModelProtoOrIr:
    """Optimizes a model.

    Args:
        model: The model to be optimized.
        num_iterations: Number of times the optimization loop is repeated.
        onnx_shape_inference: Applies node-level shape-inference as part of optimization
        input_size_limit: Will not apply constant folding to ops with any input of size
            greater than this. Does not apply to special ops like Shape() and Size().
        output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
            of the output tensor is greater than this.
        stop_if_no_change: Not supported currently (has no effect). Meant to stop the
            outer optimization loop if no change is detected in one iteration.
        inline: If True, inlines all functions in the model.

    Returns:
        The optimized model. If the input was a ModelProto, the output will also be a
        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        # In this case, optimize is done inplace.
        # TODO(justinchuby): Maybe make functional
        optimize_ir(
            model,
            num_iterations=num_iterations,
            onnx_shape_inference=onnx_shape_inference,
            stop_if_no_change=stop_if_no_change,
            input_size_limit=input_size_limit,
            output_size_limit=output_size_limit,
            inline=inline,
        )
        return model
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would otherwise let an unsupported type flow into deserialize_model.
    if not isinstance(model, onnx.ModelProto):
        raise TypeError(f"Expected ir.Model or onnx.ModelProto, got {type(model)}")
    # Round-trip through the IR: deserialize, optimize in place, serialize back.
    model_ir = ir.serde.deserialize_model(model)
    optimize_ir(
        model_ir,
        num_iterations=num_iterations,
        onnx_shape_inference=onnx_shape_inference,
        stop_if_no_change=stop_if_no_change,
        input_size_limit=input_size_limit,
        output_size_limit=output_size_limit,
        inline=inline,
    )
    # Move the model back to the proto
    new_proto = ir.serde.serialize_model(model_ir)
    return new_proto


def fold_constants(
    model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult:
    """Fold constants in a model in place.

    Args:
        model: The model to fold constants in. Modified in place. A ModelProto
            is round-tripped through the IR and then overwritten with the
            optimized result (via ``Clear``/``CopyFrom``).
        *args: Positional arguments forwarded to
            :func:`constant_folding.fold_constants`.
        **kwargs: Keyword arguments forwarded to
            :func:`constant_folding.fold_constants`.

    Returns:
        The result object produced by the underlying constant-folding pass.

    Raises:
        TypeError: If ``model`` is neither an ``ir.Model`` nor an ``onnx.ModelProto``.
    """
    if isinstance(model, ir.Model):
        return constant_folding.fold_constants(model, *args, **kwargs)
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not isinstance(model, onnx.ModelProto):
        raise TypeError(f"Expected ir.Model or onnx.ModelProto, got {type(model)}")
    model_proto = model
    model = ir.serde.deserialize_model(model_proto)
    result = constant_folding.fold_constants(model, *args, **kwargs)
    # Move the model back to the proto
    new_proto = ir.serde.serialize_model(model)
    model_proto.Clear()
    model_proto.CopyFrom(new_proto)
    return result


def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
Expand Down
27 changes: 22 additions & 5 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
e,
)

def new_constant(self, node: ir.Node, value):
def new_constant(self, node: ir.Node, value) -> ir.Node | None:
irvalue = node.outputs[0]
if not isinstance(value, np.ndarray):
# ONNX does not have a way to represent non-tensor constants, eg. a sequence.
Expand Down Expand Up @@ -965,7 +965,7 @@ def new_constant(self, node: ir.Node, value):
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node

def process_node(self, node: ir.Node):
def process_node(self, node: ir.Node) -> Replacement | None:
for i, value in enumerate(node.inputs):
sym_value = self._state.get_sym_value(value)
if isinstance(sym_value, ir.Value):
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def convert(av):
)
return None

def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)

ir.convenience.replace_nodes_and_values(
Expand All @@ -1066,13 +1066,13 @@ def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
for graph in attr.as_graphs():
self.visit_graph(graph)

def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
replacement = self.process_node(node)
if replacement is None:
# No change. Process attributes.
for attr in node.attributes.values():
self.visit_attribute(attr)
return None
return
else:
self.replace_node(node, replacement, root)

Expand All @@ -1087,6 +1087,23 @@ def visit_graph(self, graph: ir.Graph) -> None:
for node in graph:
self.visit_node(node, graph)

# Replace outputs if output nodes can be folded. This is typically outputs from
# Identity nodes
for i, output in enumerate(graph.outputs):
if output is None:
continue
sym_value = self._state.get_sym_value(output)
if not isinstance(sym_value, ir.Value):
# An output must be a Value
continue
if sym_value in graph.inputs:
# ONNX does not allow a graph output to be a graph input
continue
# Rename sym_value to match the output name
sym_value.name = output.name
graph.outputs[i] = sym_value
self.modified = True

self._state.pop_initializer_inputs()

def visit_function(self, function: ir.Function) -> None:
Expand Down
19 changes: 4 additions & 15 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,14 @@
import onnxscript.optimizer as optimizer
from onnxscript.ir import serde
from onnxscript.optimizer import _constant_folding
from onnxscript.optimizer._legacy import constant_folding


@parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
class FoldConstantsTest(unittest.TestCase):
def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False):
if self.using_ir:
ir_model = serde.deserialize_model(model)
_constant_folding.fold_constants(
ir_model, onnx_shape_inference=onnx_shape_inference
)
optimizer.remove_unused_nodes(ir_model)
return serde.serialize_model(ir_model)
else:
constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
optimizer.remove_unused_nodes(model)
return model
ir_model = serde.deserialize_model(model)
_constant_folding.fold_constants(ir_model, onnx_shape_inference=onnx_shape_inference)
optimizer.remove_unused_nodes(ir_model)
return serde.serialize_model(ir_model)

def test_fold_add(self):
model = onnx.parser.parse_model(
Expand Down Expand Up @@ -167,7 +158,6 @@ def test_fold_if_propagate(self):
"""
)
optimized = self._fold(model)
print(onnx.printer.to_text(optimized))
self.assertEqual(len(optimized.graph.node), 2)
self.assertEqual(optimized.graph.node[0].output[0], "m_square")
self.assertEqual(optimized.graph.node[0].op_type, "Constant")
Expand Down Expand Up @@ -245,7 +235,6 @@ def test_shape_inference(self):
"""
)
optimized = self._fold(model, onnx_shape_inference=True)
print(onnx.printer.to_text(optimized))
self.assertEqual(len(optimized.graph.node), 2)
self.assertEqual(optimized.graph.node[0].output[0], "C")

Expand Down
Loading
Loading