diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py
index b073b3345e..a6e8ea2fc5 100644
--- a/onnxscript/optimizer/__init__.py
+++ b/onnxscript/optimizer/__init__.py
@@ -2,14 +2,16 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
+from typing import TypeVar
+
 __all__ = [
-    "fold_constants",
-    "fold_constants_ir",
-    "remove_unused_nodes",
-    "optimize",
-    "optimize_ir",
     "basic_constant_propagation",
+    "fold_constants_ir",
+    "fold_constants",
     "inline",
+    "optimize_ir",
+    "optimize",
+    "remove_unused_nodes",
 ]
 
 import onnx
@@ -17,22 +19,73 @@
 import onnxscript.ir.passes.common.inliner
 import onnxscript.ir.passes.common.unused_removal
 import onnxscript.optimizer._constant_folding as constant_folding
-import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
-import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
 from onnxscript import ir
+from onnxscript.optimizer._constant_folding import (
+    basic_constant_propagation,
+)
+from onnxscript.optimizer._constant_folding import (
+    fold_constants as fold_constants_ir,
+)
 from onnxscript.optimizer._optimizer import optimize_ir
 
-basic_constant_propagation = constant_folding.basic_constant_propagation
-fold_constants_ir = constant_folding.fold_constants
+_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model)
+
+
+def optimize(
+    model: _ModelProtoOrIr,
+    num_iterations: int = 2,
+    *,
+    onnx_shape_inference: bool = True,
+    stop_if_no_change: bool = True,
+    input_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
+    output_size_limit: int = constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+    inline: bool = True,
+) -> _ModelProtoOrIr:
+    """Optimizes a model.
 
-def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
+    Args:
+        model: The model to be optimized.
+        num_iterations: Number of times the optimization loop is repeated.
+        onnx_shape_inference: Applies node-level shape inference as part of optimization.
+        stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+        input_size_limit: Will not apply constant folding to ops with any input of size
+            greater than this. Does not apply to special ops like Shape() and Size().
+        output_size_limit: Will not rewrite any foldable op into a Constant op if the size
+            of the output tensor is greater than this.
+        inline: If True, inlines all functions in the model.
+
+    Returns:
+        The optimized model. If the input was a ModelProto, the output will also be a
+        ModelProto. If the input was an ir.Model, the output will also be an ir.Model.
+    """
    if isinstance(model, ir.Model):
-        # In that case, this is done inplace.
-        optimize_ir(model, *args, **kwargs)
+        # In this case, optimization is done in place.
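+        # For example (illustrative only): `optimize(ir_model)` mutates and returns
+        # the same ir.Model object, whereas `optimize(model_proto)` leaves the input
+        # proto untouched and returns a new, optimized ModelProto.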
+        # TODO(justinchuby): Maybe make functional
+        optimize_ir(
+            model,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
         return model
     else:
-        return legacy_optimizer.optimize(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_ir = ir.serde.deserialize_model(model)
+        optimize_ir(
+            model_ir,
+            num_iterations=num_iterations,
+            onnx_shape_inference=onnx_shape_inference,
+            stop_if_no_change=stop_if_no_change,
+            input_size_limit=input_size_limit,
+            output_size_limit=output_size_limit,
+            inline=inline,
+        )
+        # Serialize the optimized model back to a ModelProto
+        new_proto = ir.serde.serialize_model(model_ir)
+        return new_proto
 
 
 def inline(model: ir.Model) -> None:
@@ -43,11 +96,20 @@ def inline(model: ir.Model) -> None:
 
 def fold_constants(
     model: ir.Model | onnx.ModelProto, *args, **kwargs
-) -> constant_folding.FoldConstantsResult | bool:
+) -> constant_folding.FoldConstantsResult:
+    """Fold constants in a model in place."""
     if isinstance(model, ir.Model):
         return constant_folding.fold_constants(model, *args, **kwargs)
     else:
-        return legacy_constant_folding.fold_constants(model, *args, **kwargs)
+        assert isinstance(model, onnx.ModelProto)
+        model_proto = model
+        model = ir.serde.deserialize_model(model_proto)
+        result = constant_folding.fold_constants(model, *args, **kwargs)
+        # Serialize the result and copy it back into the caller's proto
+        new_proto = ir.serde.serialize_model(model)
+        model_proto.Clear()
+        model_proto.CopyFrom(new_proto)
+        return result
 
 
 def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index e8db6450dd..cce74cb132 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -919,7 +919,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
                 e,
             )
 
-    def new_constant(self, node: ir.Node, value):
+    def new_constant(self, node: ir.Node, value) -> ir.Node | None:
         irvalue = node.outputs[0]
         if not isinstance(value, np.ndarray):
             # ONNX does not have a way to represent non-tensor constants, eg. a sequence.
@@ -965,7 +965,7 @@ def new_constant(self, node: ir.Node, value):
         node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
         return node
 
-    def process_node(self, node: ir.Node):
+    def process_node(self, node: ir.Node) -> Replacement | None:
         for i, value in enumerate(node.inputs):
             sym_value = self._state.get_sym_value(value)
             if isinstance(sym_value, ir.Value):
@@ -1046,7 +1046,7 @@ def convert(av):
             )
         return None
 
-    def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function):
+    def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None:
         logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name)
 
         ir.convenience.replace_nodes_and_values(
@@ -1066,13 +1066,13 @@ def visit_attribute(self, attr: ir.Attr | ir.RefAttr) -> None:
             for graph in attr.as_graphs():
                 self.visit_graph(graph)
 
-    def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function):
+    def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None:
         replacement = self.process_node(node)
         if replacement is None:
             # No change. Process attributes.
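+            # (Attributes may hold subgraphs, e.g. If branches, that contain foldable nodes.)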
            for attr in node.attributes.values():
                 self.visit_attribute(attr)
-            return None
+            return
         else:
             self.replace_node(node, replacement, root)
 
@@ -1087,6 +1087,22 @@ def visit_graph(self, graph: ir.Graph) -> None:
         for node in graph:
             self.visit_node(node, graph)
 
+        # Replace outputs if output nodes can be folded. These are typically outputs from
+        # Identity nodes.
+        for i, output in enumerate(graph.outputs):
+            if output is None:
+                continue
+            sym_value = self._state.get_sym_value(output)
+            if not isinstance(sym_value, ir.Value):
+                # An output must be a Value
+                continue
+            if not _sym_value_can_replace_graph_output(graph, sym_value, output):
+                continue
+            # Rename sym_value to match the output name
+            sym_value.name = output.name
+            graph.outputs[i] = sym_value
+            self.modified = True
+
         self._state.pop_initializer_inputs()
 
     def visit_function(self, function: ir.Function) -> None:
@@ -1103,6 +1119,24 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
         return FoldConstantsResult(model, self.modified, self._state.symbolic_value_map)
 
 
+def _sym_value_can_replace_graph_output(
+    graph: ir.Graph, sym_value: ir.Value, output: ir.Value
+) -> bool:
+    if (producer := sym_value.producer()) is None:
+        # If the sym_value has no producer, it is some graph's input.
+        # ONNX does not allow a graph input to be a graph output.
+        return False
+    if producer.graph is not graph:
+        # The sym_value must be produced by a node in the graph to be an output of this graph
+        return False
+    if sym_value.is_graph_output():
+        # If the sym_value is already an output of a graph, we cannot rename it
+        # to this output name. Otherwise the graph output represented by sym_value
+        # will lose its name.
+        return False
+    return True
+
+
 @dataclasses.dataclass
 class FoldConstantsResult(ir.passes.PassResult):
     symbolic_value_map: dict[ir.Value, SymbolicValue]
diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py
index 8738dd0de9..81ed911c9e 100644
--- a/onnxscript/optimizer/_constant_folding_test.py
+++ b/onnxscript/optimizer/_constant_folding_test.py
@@ -7,33 +7,32 @@
 import numpy as np
 import onnx
 import parameterized
-import pytest
 
-import onnxscript.ir as ir
 import onnxscript.optimizer as optimizer
-from onnxscript.ir import serde
+from onnxscript import ir
 from onnxscript.optimizer import _constant_folding
-from onnxscript.optimizer._legacy import constant_folding
 
 
-@parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
+def _create_model(model_text: str) -> ir.Model:
+    """Create a model from the given text."""
+    model = onnx.parser.parse_model(model_text)
+    return ir.serde.deserialize_model(model)
+
+
 class FoldConstantsTest(unittest.TestCase):
-    def _fold(self, model: onnx.ModelProto, onnx_shape_inference=False):
-        if self.using_ir:
-            ir_model = serde.deserialize_model(model)
-            _constant_folding.fold_constants(
-                ir_model, onnx_shape_inference=onnx_shape_inference
-            )
-            optimizer.remove_unused_nodes(ir_model)
-            return serde.serialize_model(ir_model)
-        else:
-            constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
-            optimizer.remove_unused_nodes(model)
-            return model
+    def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
+        if isinstance(model, str):
+            model = _create_model(model)
+        _constant_folding.fold_constants(
+            model, onnx_shape_inference=onnx_shape_inference, **kwargs
+        )
+        optimizer.remove_unused_nodes(model)
+        # Ensure the model is valid after optimization
onnx.checker.check_model(ir.serde.serialize_model(model)) + return model def test_fold_add(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[N] x) => (float[N] z) { two = Constant () @@ -41,14 +40,13 @@ def test_fold_add(self): z = Mul(x, four) } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") def test_fold_cast_like(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[N] x) => (float[N] z) { two = Constant () @@ -57,14 +55,13 @@ def test_fold_cast_like(self): z = Mul(x, four) } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") def test_fold_shape(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[16, 16] x) => (float[16, 16] z) { shape = Shape(x) @@ -74,14 +71,13 @@ def test_fold_shape(self): z = Mul(x, four) } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") def test_fold_shape_slice(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) { shape = Shape (x) @@ -91,14 +87,13 @@ def test_fold_shape_slice(self): z = Mul(x, four) } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "four") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "four") def test_fold_if_cond(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[16, 16] x) => (float[16, 16] z) { shape = Shape(x) @@ -112,15 +107,14 @@ def test_fold_if_cond(self): > } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 1) - self.assertEqual(optimized.graph.node[0].output[0], "z") - self.assertEqual(optimized.graph.node[0].op_type, "Mul") + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph[0].outputs[0].name, "z") + self.assertEqual(optimized.graph[0].op_type, "Mul") def test_fold_inside_if_branch(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { two = Constant () @@ -138,17 +132,16 @@ def test_fold_inside_if_branch(self): > } """ - ) + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 1) - then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch") - self.assertEqual(len(then_graph.node), 2) - else_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "else_branch") - self.assertEqual(len(else_graph.node), 2) + self.assertEqual(len(optimized.graph), 1) + then_graph = optimized.graph[0].attributes["then_branch"].as_graph() + self.assertEqual(len(then_graph), 2) + else_graph = optimized.graph[0].attributes["else_branch"].as_graph() + self.assertEqual(len(else_graph), 2) def test_fold_if_propagate(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[16, 16] x) => (float[16, 16] z) { shape = Shape(x) @@ -165,16 
+158,14 @@ def test_fold_if_propagate(self): z = Mul (x, m_square) } """ - ) + optimized = self._fold(model) - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "m_square") - self.assertEqual(optimized.graph.node[0].op_type, "Constant") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "m_square") + self.assertEqual(optimized.graph[0].op_type, "Constant") def test_fold_redundant_cast(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[N] x) => (float[N] z) { two = Constant () @@ -182,48 +173,27 @@ def test_fold_redundant_cast(self): z = Mul(x_cast, two) } """ - ) + optimized = self._fold(model, onnx_shape_inference=True) - self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph), 2) def test_fold_redundant_cast2(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (float[N] x) => (float[N] z) { two = Constant () z = CastLike(x, two) } """ - ) + optimized = self._fold(model, onnx_shape_inference=True) - self.assertEqual(len(optimized.graph.node), 1) - self.assertEqual(optimized.graph.node[0].op_type, "Identity") - self.assertEqual(optimized.graph.node[0].output[0], "z") - self.assertEqual(optimized.graph.node[0].input[0], "x") - - @pytest.mark.skip(reason="Feature removed to catch errors early") - def test_fold_undefined_vars(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - four = Add(two, two) - y = Shape(t1) - w = CastLike(x, t2) - w2 = CastLike(t3, t4) - w3 = Size(t5) - z = Sum (four, y, w, w2, w3) - } - """ - ) - # No optimizations expected. Just make sure it doesn't crash. - optimized = self._fold(model, onnx_shape_inference=False) - self.assertEqual(len(optimized.graph.node), 6) + self.assertEqual(len(optimized.graph), 1) + self.assertEqual(optimized.graph[0].op_type, "Identity") + self.assertEqual(optimized.graph[0].outputs[0].name, "z") + self.assertEqual(optimized.graph[0].inputs[0].name, "x") def test_shape_inference(self): - model = onnx.parser.parse_model( - """ + model = """ agraph (int64[64] x) => (int64[N] z) { one = Constant () @@ -243,22 +213,20 @@ def test_shape_inference(self): z = Mul(x, C) } """ - ) + optimized = self._fold(model, onnx_shape_inference=True) - print(onnx.printer.to_text(optimized)) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(optimized.graph.node[0].output[0], "C") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(optimized.graph[0].outputs[0].name, "C") def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( self, ): - model = onnx.parser.parse_model( - """ + model = """ < ir_version: 8, opset_import: ["" : 18] > -func (float[1,512] x) => ( return_val) { +func (float[1,512] x) => (float[1,512] return_val) { int64_128 = Constant () splits = SplitToSequence (x, int64_128) int64_0 = Constant () @@ -270,47 +238,43 @@ def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_ int64_3 = Constant () split_3 = SequenceAt (splits, int64_3) return_val = Concat (split_0, split_1, split_2, split_3) -} - """ - ) +}""" # TODO: There is an unrelated limitation that `symbolic_value` is not # utilized when the value is only referenced by graph output. # E.g., the following test model will not have this optimization # applied. 
- """ -< - ir_version: 8, - opset_import: ["" : 18] -> -func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { - int64_128 = Constant () - splits = SplitToSequence (x, int64_128) - int64_0 = Constant () - split_0 = SequenceAt (splits, int64_0) - int64_1 = Constant () - split_1 = SequenceAt (splits, int64_1) - int64_2 = Constant () - split_2 = SequenceAt (splits, int64_2) - int64_3 = Constant () - split_3 = SequenceAt (splits, int64_3) -} - """ + # + # < + # ir_version: 8, + # opset_import: ["" : 18] + # > + # func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { + # int64_128 = Constant () + # splits = SplitToSequence (x, int64_128) + # int64_0 = Constant () + # split_0 = SequenceAt (splits, int64_0) + # int64_1 = Constant () + # split_1 = SequenceAt (splits, int64_1) + # int64_2 = Constant () + # split_2 = SequenceAt (splits, int64_2) + # int64_3 = Constant () + # split_3 = SequenceAt (splits, int64_3) + # } optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 2) - self.assertEqual(len(optimized.graph.node[-2].output), 4) - self.assertEqual(optimized.graph.node[-2].op_type, "Split") + self.assertEqual(len(optimized.graph), 2) + self.assertEqual(len(optimized.graph[-2].outputs), 4) + self.assertEqual(optimized.graph[-2].op_type, "Split") def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( self, ): - model = onnx.parser.parse_model( - """ + model = """ < ir_version: 8, opset_import: ["" : 18] > -func (float[1,512] x) => ( return_val) { +func (float[1,512] x) => (float[1,N] return_val) { const = Constant () splits = SplitToSequence (x, const) int64_0 = Constant () @@ -320,24 +284,22 @@ def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_sp int64_2 = Constant () split_2 = SequenceAt (splits, int64_2) return_val = Concat (split_0, split_1, split_2) -} - """ - ) +}""" + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 3) - self.assertEqual(len(optimized.graph.node[-2].output), 3) - self.assertEqual(optimized.graph.node[-2].op_type, "Split") + self.assertEqual(len(optimized.graph), 3) + self.assertEqual(len(optimized.graph[-2].outputs), 3) + self.assertEqual(optimized.graph[-2].op_type, "Split") def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( self, ): - model = onnx.parser.parse_model( - """ + model = """ < ir_version: 8, opset_import: ["" : 18] > -func (float[1,3] x) => ( return_val) { +func (float[1,3] x) => (float[1,3] return_val) { const = Constant () splits = SplitToSequence (x, const) int64_0 = Constant () @@ -347,20 +309,17 @@ def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_ int64_2 = Constant () split_2 = SequenceAt (splits, int64_2) return_val = Concat (split_0, split_1, split_2) -} - """ - ) +}""" optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 7) - self.assertEqual(len(optimized.graph.node[1].output), 3) - self.assertEqual(optimized.graph.node[1].op_type, "Split") - self.assertEqual(len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3) + self.assertEqual(len(optimized.graph), 7) + self.assertEqual(len(optimized.graph[1].outputs), 3) + self.assertEqual(optimized.graph[1].op_type, "Split") + self.assertEqual(len([n for n in optimized.graph if n.op_type == "Squeeze"]), 3) def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( self, ): - model = onnx.parser.parse_model( - """ + model = """ < ir_version: 8, 
opset_import: ["" : 18] @@ -369,19 +328,16 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) -} - """ - ) +}""" + optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 3) - self.assertEqual(optimized.graph.node[2].op_type, "Concat") - onnx.checker.check_model(optimized) + self.assertEqual(len(optimized.graph), 3) + self.assertEqual(optimized.graph[2].op_type, "Concat") def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( self, ): - model = onnx.parser.parse_model( - """ + model = """ < ir_version: 8, opset_import: ["" : 18] @@ -390,24 +346,11 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( const = Constant () splits = SplitToSequence (x, const) return_val = ConcatFromSequence (splits) -} - """ - ) - optimized = self._fold(model) - self.assertEqual(len(optimized.graph.node), 7) - self.assertEqual(optimized.graph.node[6].op_type, "Concat") - onnx.checker.check_model(optimized) - +}""" -class FoldConstantsIrTest(unittest.TestCase): - def _fold(self, model: str | onnx.ModelProto | ir.Model, **kwargs) -> ir.Model: - if isinstance(model, str): - model = onnx.parser.parse_model(model) - if isinstance(model, onnx.ModelProto): - model = serde.deserialize_model(model) - _constant_folding.fold_constants(model, **kwargs) - optimizer.remove_unused_nodes(model) - return model + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 7) + self.assertEqual(optimized.graph[6].op_type, "Concat") def test_initializer_input_not_folded(self): model_text = """ @@ -417,8 +360,7 @@ def test_initializer_input_not_folded(self): # c is not a constant, and following should not be folded. two_c = Add (c, c) z = Mul (x, two_c) - } - """ + }""" optimized = self._fold(model_text) self.assertEqual(len(optimized.graph), 2) self.assertEqual(optimized.graph.node(0).op_type, "Add") @@ -601,7 +543,7 @@ def test_gather_symdim(self): self.assertEqual(optimized.graph.node(-1).op_type, "Identity") def test_large_transpose(self): - model = """ + model_text = """ agraph (float[M, 256] x) => (float[M, 512] z) # placeholder for large initializer of shape [512, 256] @@ -610,22 +552,38 @@ def test_large_transpose(self): z = MatMul (x, wt) } """ - irmodel = serde.deserialize_model(onnx.parser.parse_model(model)) - w = irmodel.graph.initializers["w"] + model = _create_model(model_text) + w = model.graph.initializers["w"] w.shape = ir.Shape([512, 256]) w.const_value = ir.tensor(np.random.random((512, 256)).astype(np.float32)) # Input size limit will prevent folding of Transpose op - optimized = self._fold(irmodel, input_size_limit=3 * 512 * 256) + optimized = self._fold(model, input_size_limit=3 * 512 * 256) ops = [node.op_type for node in optimized.graph] self.assertEqual(ops, ["Transpose", "MatMul"]) # Input size limit will allow folding of Transpose op # Since there is no increase in model-size, output-size is not a concern. 
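+        # A note on the arithmetic (assuming the limit is compared against byte size):
+        # w is a float32 tensor of shape [512, 256], i.e. 4 * 512 * 256 bytes, so a
+        # limit of 3 * 512 * 256 falls below it while 4 * 512 * 256 just admits it.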
- optimized = self._fold(irmodel, input_size_limit=4 * 512 * 256) + optimized = self._fold(model, input_size_limit=4 * 512 * 256) ops = [node.op_type for node in optimized.graph] self.assertEqual(ops, ["Constant", "MatMul"]) + def test_multi_graph_identity_output_preserves_output_name(self): + model = """ + + agraph (float[N] x) => (float[N] graph_output1, float[N] graph_output2) { + t = Identity(x) + graph_output1 = Identity(t) + graph_output2 = Identity(t) + }""" + optimized = self._fold(model) + self.assertEqual(len(optimized.graph), 2) + self.assertEqual([n.op_type for n in optimized.graph], ["Identity", "Identity"]) + self.assertEqual( + [n.outputs[0].name for n in optimized.graph], ["graph_output1", "graph_output2"] + ) + self.assertEqual([input.name for input in optimized.graph.inputs], ["x"]) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/optimizer/_function_folding_test.py b/onnxscript/optimizer/_function_folding_test.py index 1d911bd911..5e7de8b0de 100644 --- a/onnxscript/optimizer/_function_folding_test.py +++ b/onnxscript/optimizer/_function_folding_test.py @@ -5,12 +5,18 @@ import onnx import onnxscript.testing -from onnxscript import optimizer +from onnxscript import ir, optimizer + + +def _create_model(model_text: str) -> ir.Model: + """Create a model from the given text.""" + model = onnx.parser.parse_model(model_text) + return ir.serde.deserialize_model(model) class FunctionFoldingTest(unittest.TestCase): def test_identity(self): - model = onnx.parser.parse_model( + model = _create_model( """ agraph (float[N] x1, bool cond1) => (float[N] z1) { @@ -32,19 +38,16 @@ def test_identity(self): > t4 = Add(t3, t3) z = Identity(t4) - } - """ + }""" ) optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, + model, onnx_shape_inference=False, num_iterations=1, inline=True ) self.assertEqual(len(optimized.functions), 0) - self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph), 2) def test_sequence_concat(self): - model = onnx.parser.parse_model( + model = _create_model( """ agraph (float[N] x1) => (float[M] z1) { @@ -55,21 +58,18 @@ def test_sequence_concat(self): t0 = Add (x, x) t2 = Add (x, x) t3 = SequenceConstruct (x, t0, t2, x) - z = ConcatFromSequence (t3) - } - """ + z = ConcatFromSequence (t3) + }""" ) optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, + model, onnx_shape_inference=False, num_iterations=1, inline=False ) - function_node = optimized.functions[0].node - self.assertEqual(len(function_node), 3) - self.assertEqual(function_node[2].op_type, "Concat") + function = optimized.functions[("local", "fun1", "")] + self.assertEqual(len(function), 3) + self.assertEqual(function[2].op_type, "Concat") def test_sequence_at(self): - model = onnx.parser.parse_model( + model = _create_model( """ agraph (float[N] x) => (float[M] z) { @@ -78,27 +78,25 @@ def test_sequence_at(self): s = SequenceConstruct (x, t0, t1) one = Constant () z = SequenceAt (s, one) - } - """ + }""" ) optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, + model, onnx_shape_inference=False, num_iterations=1, inline=False ) - expected = onnx.parser.parse_model( + expected = _create_model( """ agraph (float[N] x) => (float[M] z) { - t0 = Add (x, x) - z = Identity (t0) - } - """ + z = Add (x, x) + }""" + ) + # TODO(justinchuby): Implement assert_isomorphic_graph for IR objects + onnxscript.testing.assert_isomorphic_graph( + 
ir.to_proto(optimized.graph), ir.to_proto(expected.graph) ) - onnxscript.testing.assert_isomorphic_graph(optimized.graph, expected.graph) def test_single_user_function_is_modified_inplace_after_folding(self): - model = onnx.parser.parse_model( + model = _create_model( """ agraph (float[N] x1) => (float[M] z1) { @@ -110,84 +108,51 @@ def test_single_user_function_is_modified_inplace_after_folding(self): t2 = Add (x, x) t3 = SequenceConstruct (x, t0, t2, x) z = ConcatFromSequence (t3) - } - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, - ) - self.assertEqual(optimized.functions[0].name, "fun1") - - def test_multi_users_function_is_not_modified_inplace_after_folding(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x1) => (float[M] z1, float[M] z2) { - z1 = local.fun1(x1) - z2 = local.fun1(x1) - } - - fun1 (x) => (z) { - t0 = Add (x, x) - t2 = Add (x, x) - t3 = SequenceConstruct (x, t0, t2, x) - z = ConcatFromSequence (t3) - } - """ + }""" ) optimized = optimizer.optimize( - model, - onnx_shape_inference=False, - num_iterations=1, + model, onnx_shape_inference=False, num_iterations=1, inline=False ) - self.assertEqual(len(optimized.functions), 2) - self.assertNotEqual(optimized.functions[0].name, "fun1") - self.assertNotEqual(optimized.functions[1].name, "fun1") + self.assertEqual(next(iter(optimized.functions.values())).name, "fun1") def test_fold_nested_if_function_succeeds(self): - model = onnx.parser.parse_model( + model = _create_model( """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x, float[1,512] y) => ( out) { - out = this.foldable_func (x, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable_func (x, y) => (z_6) -{ - cond = Constant () - z_6 = If (cond) ( z_2) { - cond_0 = Not (cond) - z_2 = If (cond_0) ( z) { - z = Add (x, x) - }, else_branch: graph = elseGraph_5 () => ( z_1) { - z_1 = Identity (x) - }> - }, else_branch: graph = elseGraph_4 () => ( z_5) { - z_5 = If (cond) ( z_3) { - z_3 = Add (y, y) - }, else_branch: graph = elseGraph_10 () => ( z_4) { - z_4 = Add (x, y) - }> - }> -} - """ - ) - optimized = optimizer.optimize( - model, - onnx_shape_inference=False, + < + ir_version: 9, + opset_import: ["this" : 1, "" : 18] + > + func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) + } + < + domain: "this", + opset_import: ["" : 18] + > + foldable_func (x, y) => (z_6) + { + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> + }""" ) + optimized = optimizer.optimize(model, onnx_shape_inference=False, inline=True) self.assertEqual(len(optimized.functions), 0) - self.assertEqual(len(optimized.graph.node), 1) - self.assertNotIn("If", {n.op_type for n in optimized.graph.node}) + self.assertEqual(len(optimized.graph), 2) + self.assertNotIn("If", {n.op_type for n in optimized.graph}) if __name__ == "__main__": diff --git a/onnxscript/optimizer/_legacy/_optimizer.py b/onnxscript/optimizer/_legacy/_optimizer.py deleted file mode 100644 index 829eb9c25f..0000000000 --- a/onnxscript/optimizer/_legacy/_optimizer.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from __future__ import annotations - -import logging -from typing import Any - -import onnx -import onnx.shape_inference - -import onnxscript.optimizer -from onnxscript import rewriter -from onnxscript.optimizer._legacy._simple_function_folding import ( - inline_functions_with_unused_outputs, - inline_simple_functions, -) -from onnxscript.optimizer._legacy.constant_folding import fold_constants - -logger = logging.getLogger(__name__) - - -def optimize( - model: onnx.ModelProto, - num_iterations: int = 2, - *, - onnx_shape_inference: bool = True, - stop_if_no_change: bool = True, - external_data_folder: str = "", - **kwargs: Any, -) -> onnx.ModelProto: - """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. - - Args: - model (onnx.ModelProto): The model to optimize. - num_iterations (int, optional): Number of iterations to perform. - onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. - Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. - This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries - the symbolic shapes recorded from dynamo tracing. - stop_if_no_change (bool, optional): Whether to stop if no change is detected. - external_data_folder (str, optional): The folder to store external data. - **kwargs: Additional keyword arguments. For BC purposes. - """ - if kwargs.pop("function_aware_folding", None) is not None: - logger.warning( - "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " - "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " - "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " - "See 'onnx_shape_inference' for more details." - ) - for _ in range(num_iterations): - if onnx_shape_inference: - if model.ByteSize() < 1024 * 1024 * 1024 * 2: - # NOTE: strict mode is disabled because it crashes on the models - # that have different shapes inferred from the model carried shapes. - # The case can be found in: - # https://github.com/microsoft/onnxscript/issues/1443 - model = onnx.shape_inference.infer_shapes( - model, check_type=True, strict_mode=False, data_prop=True - ) - else: - logger.warning( - "The model size is too large for full model shape inference. " - "Skipping this step." 
- ) - - inline_simple_functions(model) - modified = fold_constants( - model, external_data_folder, onnx_shape_inference=onnx_shape_inference - ) - - onnxscript.optimizer.remove_unused_nodes(model) - inline_simple_functions(model) - onnxscript.optimizer.remove_unused_functions(model) - inline_functions_with_unused_outputs(model) - # NOTE: This is general rewrite rules - model = rewriter.rewrite(model) - if stop_if_no_change and not modified: - logger.debug("Stopping after %d iterations.", _) - break - - for node in model.graph.node: - logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) - - for function in model.functions: - for node in function.node: - logger.debug( - "Function %s::%s node %s::%s name %s.", - function.domain, - function.name, - node.domain, - node.op_type, - node.name, - ) - - return model diff --git a/onnxscript/optimizer/_legacy/_remove_unused_proto.py b/onnxscript/optimizer/_legacy/_remove_unused_proto.py deleted file mode 100644 index 78dbf49b5b..0000000000 --- a/onnxscript/optimizer/_legacy/_remove_unused_proto.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import logging -from typing import Sequence - -import onnx -from google.protobuf.internal.containers import ( # type: ignore - RepeatedCompositeFieldContainer, -) - -logger = logging.getLogger(__name__) - - -def remove_unused_optional_outputs( - n: onnx.NodeProto, used: set, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> None: - try: - if n.domain not in {"", "onnx.ai"}: - return - onnx_opset_version = 1 - for opset in opset_import: - if opset.domain == n.domain: - onnx_opset_version = opset.version - op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain) - except Exception: - return - - if n.op_type == "BatchNormalization": - # BatchNormalization op has 3 outputs: Y, running_mean, running_var - # If running_mean and running_var are not used, remove them, and the training_mode attribute - def is_used_output(i: int) -> bool: - if i < len(n.output): - return n.output[i] in used - return False - - if is_used_output(1) or is_used_output(2): - return - del n.output[1:] - for j, attr in enumerate(n.attribute): - if attr.name == "training_mode": - del n.attribute[j] - break - - optional_info = [] - for o in op_schema.outputs: - # Current ops do not have optional outputs if they have variable number of outputs - if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: - return - optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) - # If no optional outputs in spec, skip delete operations - if len([o == 1 for o in optional_info]) == 0: - return - - for i, out in enumerate(n.output): - if out not in used and optional_info[i] is True: - n.output[i] = "" - # Only delete trailing unused optional outputs - for o in n.output[::-1]: # type: ignore[assignment] - if o == "": - n.output.pop() - else: - return - - -def compute_used_in_node(n: onnx.NodeProto) -> set[str]: - used = {n for n in n.input if n != ""} - for attr in n.attribute: - if attr.HasField("g"): - used |= compute_used_in_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - used |= compute_used_in_graph(graph) - return used - - -def compute_used_in_graph(g: onnx.GraphProto) -> set[str]: - used = set() - for n in g.node: - used |= compute_used_in_node(n) - return used - - -def process_nodes( - nodes: RepeatedCompositeFieldContainer[onnx.NodeProto], - used: 
set, - opset_import: Sequence[onnx.OperatorSetIdProto], -) -> int: - count = 0 - i = len(nodes) - 1 - while i >= 0: - node = nodes[i] - remove_unused_optional_outputs(node, used, opset_import) - used_outputs = [x for x in node.output if x in used] - if not used_outputs: - del nodes[i] - count += 1 - i -= 1 - continue - for attr in node.attribute: - if attr.HasField("g"): - process_graph(attr.g, opset_import) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - process_graph(graph, opset_import) - used |= compute_used_in_node(node) - i -= 1 - return count - - -def process_graph( - graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> int: - used = {output.name for output in graph.output} - - count = process_nodes(graph.node, used, opset_import) - - new_initializers = [] - for init in graph.initializer: - if init.name not in used: - count += 1 - continue - new_initializers.append(init) - del graph.initializer[:] - graph.initializer.extend(new_initializers) - return count - - -def process_function( - function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto] -) -> int: - used = set(function.output) - - return process_nodes(function.node, used, opset_import) - - -def remove_unused_nodes(model: onnx.ModelProto) -> None: - """Removes unused nodes from the model.""" - count = process_graph(model.graph, model.opset_import) - for function in model.functions: - count += process_function(function, model.opset_import) - - logger.info("Removed %s unused nodes", count) diff --git a/onnxscript/optimizer/_legacy/_simple_function_folding.py b/onnxscript/optimizer/_legacy/_simple_function_folding.py deleted file mode 100644 index 829bae9d62..0000000000 --- a/onnxscript/optimizer/_legacy/_simple_function_folding.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Inlines the function if it only contains very few number of nodes.""" - -from __future__ import annotations - -import logging -from typing import Sequence - -import onnx - -import onnxscript._legacy_ir as ir -from onnxscript._legacy_ir import visitor -from onnxscript.optimizer._legacy import _remove_unused_proto - -logger = logging.getLogger(__name__) - - -class FunctionInliner(visitor.FunctionCallsiteProtoTransformer): - counts: dict[ir.FunctionId, int] - - def __init__(self, node_count: int) -> None: - super().__init__() - self._node_count = node_count - - def _gather_function_metadata(self, model: onnx.ModelProto) -> None: - super()._gather_function_metadata(model) - self._function_renamer._postfix = "inlined" - - def visit_model(self, model: onnx.ModelProto) -> None: - self.counts = {} - - super().visit_model(model) - - def should_inline_function(self, function: onnx.FunctionProto) -> bool: - return len(function.node) <= self._node_count - - def process_function_node( - self, node: onnx.NodeProto - ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: - # Recursively process sub nodes first. 
- function_id = (node.domain, node.op_type, getattr(node, "overload", "")) - function = self._functions[function_id] - replacement, new_function = super().process_function_node(node) - function = new_function if new_function else function - - if self.should_inline_function(function): - self.enter_function_scope(function) - sub_scope = self.exit_function_scope(function) - new_nodes = [] - - formal_outs = function.output - actual_outs = node.output - formal_ins = function.input - actual_ins = node.input - # TODO: Potential collision when actual is "". - # formal.name may collide with existing value names. - input_renamings = dict(zip(formal_ins, actual_ins)) - if len(actual_ins) < len(formal_ins): - input_renamings.update(dict.fromkeys(formal_ins[len(actual_ins) :], "")) - output_renamings = { - formal: actual - for formal, actual in zip(formal_outs, actual_outs) - if actual != "" - } - renamings = {**input_renamings, **output_renamings} - - logger.debug("renamings function %s: %s", function.name, renamings) - - def rename(name: str) -> str: - if name == "": - return name - new_name = renamings.get(name) - if new_name is None: - new_name = f"{node.name}_{name}" - logger.debug("renaming %s to %s", name, new_name) - if (ir_value := sub_scope.lookup(name)) is not None: - if ir_value.tensor_shape_proto() is not None and ir_value.type is not None: - ir_value.name = new_name - self.bind(new_name, ir_value) - return new_name - - ref_attrs = {attr.name: attr for attr in node.attribute} - # logger.debug("inlining simple function %s. Ref attrs: %s", function.name, ref_attrs) - - def fill_in_ref(attr: onnx.AttributeProto) -> onnx.AttributeProto: - if attr.ref_attr_name: - new_attr = onnx.AttributeProto() - new_attr.CopyFrom(ref_attrs[attr.ref_attr_name]) - new_attr.name = attr.name - return new_attr - return attr - - def update_graph_attribute( - attr: onnx.AttributeProto, - ) -> onnx.AttributeProto: - if attr.g: - new_attr = onnx.AttributeProto() - new_attr.CopyFrom(attr) - for node in new_attr.g.node: - node.input[:] = [rename(name) for name in node.input] - node.output[:] = [rename(name) for name in node.output] - new_attrs = [] - for attr in node.attribute: - new_attrs.append(update_attribute(attr)) - del node.attribute[:] - node.attribute.extend(new_attrs) - for vi_proto in new_attr.g.input: - vi_proto.name = rename(vi_proto.name) - for vi_proto in new_attr.g.output: - vi_proto.name = rename(vi_proto.name) - return new_attr - return attr - - def update_attribute(attr: onnx.AttributeProto) -> onnx.AttributeProto: - new_attr = fill_in_ref(attr) - new_attr = update_graph_attribute(new_attr) - return new_attr - - for sub_node in function.node: - # logger.debug("inlining simple function. old node: %s", sub_node) - new_node = onnx.NodeProto() - new_node.CopyFrom(sub_node) - new_node.input[:] = [rename(name) for name in new_node.input] - new_node.output[:] = [rename(name) for name in new_node.output] - del new_node.attribute[:] - for attr in sub_node.attribute: - new_node.attribute.append(update_attribute(attr)) - # Avoid name collision. - new_node.name = f"{node.name}_{new_node.name}" - # logger.debug("inlining simple function. new node: %s", new_node) - new_nodes.append(new_node) - - self.counts.setdefault(function_id, 0) - self.counts[function_id] += 1 - - return new_nodes, None - - return replacement, new_function - - -class SelectedFunctionInliner(FunctionInliner): - def __init__(self, functions_to_inline: Sequence[onnx.FunctionProto]): - super().__init__(node_count=0) # node_count unused. 
- self._functions_to_inline = functions_to_inline - - def should_inline_function(self, function: onnx.FunctionProto) -> bool: - return function in self._functions_to_inline - - -class FindFunctionWithUnusedOutputsVisitor(visitor.ProtoVisitor): - def __init__(self) -> None: - super().__init__() - self._function_with_unused_outputs: dict[ir.FunctionId, onnx.FunctionProto] = {} - self._functions: dict[ir.FunctionId, onnx.FunctionProto] = {} - self._used_nodes: list[onnx.NodeProto] = [] - - def _find_nodes_with_any_unused_output( - self, nodes: Sequence[onnx.NodeProto], used_values: set[str] - ) -> list[onnx.NodeProto]: - target_nodes = [] - for i in range(len(nodes) - 1, -1, -1): - node = nodes[i] - if any(x not in used_values for x in node.output): - # Any unused output means the node is a target node. - target_nodes.append(node) - if all(x not in used_values for x in node.output): - # All unused output means the node is not used at all. - # Hence do not update used_values with the node's inputs. - continue - used_values |= _remove_unused_proto.compute_used_in_node(node) - return target_nodes - - def visit_model(self, model: onnx.ModelProto) -> None: - used_values = {output.name for output in model.graph.output} - target_nodes = self._find_nodes_with_any_unused_output(model.graph.node, used_values) - - for function in model.functions: - self._functions[ - (function.domain, function.name, getattr(function, "overload", "")) - ] = function - used_values = set(function.output) - target_nodes.extend( - self._find_nodes_with_any_unused_output(function.node, used_values) - ) - - for node in target_nodes: - if visitor.is_local_function_node(node, self._functions): - function_id = (node.domain, node.op_type, getattr(node, "overload", "")) - self._function_with_unused_outputs[function_id] = self._functions[function_id] - - logger.info( - "Found %s function nodes that have unused outputs.", - len(self._function_with_unused_outputs), - ) - for key in self._function_with_unused_outputs: - logger.info("Function node with unused outputs: %s::%s", key[0], key[1]) - - @property - def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto]: - return self._function_with_unused_outputs - - -def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: - """Inlines simple functions based on a node count threshold""" - inliner = FunctionInliner(node_count) - inliner.visit_model(model) - logger.info( - "inlined %s simple functions based on node count threshold %s.", - len(inliner.counts), - node_count, - ) - for op in inliner.counts: - logger.info( - "Inlined simple function '%s::%s' %s times.", - op[0], - op[1], - inliner.counts[op], - ) - return inliner.modified - - -def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: - """Inlines function nodes that have unused outputs.""" - # TODO: Use onnx.inliner after 1.16. - # This visitor based inliner is used to ensure the function inner value info remains consistent. 
- visitor = FindFunctionWithUnusedOutputsVisitor() - visitor.visit_model(model) - # FIXME: Fix the type of the argument passed into SelectedFunctionInliner - inliner = SelectedFunctionInliner(visitor.function_with_unused_outputs.values()) # type: ignore[arg-type] - inliner.visit_model(model) - logger.info( - "inlined %s function nodes that have unused outputs.", - len(inliner.counts), - ) - for op in inliner.counts: - logger.info( - "Inlined function '%s::%s' %s times.", - op[0], - op[1], - inliner.counts[op], - ) - return inliner.modified diff --git a/onnxscript/optimizer/_legacy/_simple_function_folding_test.py b/onnxscript/optimizer/_legacy/_simple_function_folding_test.py deleted file mode 100644 index 8e0dcf94f5..0000000000 --- a/onnxscript/optimizer/_legacy/_simple_function_folding_test.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import onnx - -from onnxscript import ir -from onnxscript.ir.passes.common import unused_removal -from onnxscript.optimizer._legacy import _simple_function_folding - - -def _remove_unused_functions(model_proto: onnx.ModelProto) -> onnx.ModelProto: - model = ir.serde.deserialize_model(model_proto) - model = unused_removal.RemoveUnusedFunctionsPass()(model).model - return ir.serde.serialize_model(model) - - -class SingleNodeFunctionFoldingTest(unittest.TestCase): - def test_fold_single_node_function(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y) => ( return_val) { - tmp = this.foldable (x) - return_val = Add (tmp, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x) => (return_val) -{ - return_val = Identity (x) -} - """ - ) - - _simple_function_folding.inline_simple_functions(model) - model = _remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - - def test_fold_single_node_function_ref_attr(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val) -{ - return_val = Concat (x, y) -} - """ - ) - - _simple_function_folding.inline_simple_functions(model) - model = _remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) - self.assertEqual(model.graph.node[0].attribute[0].name, "axis") - - def test_fold_single_node_function_nested(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.non_foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val) -{ - return_val = Concat (x, y) -} -< - domain: "this", - opset_import: ["this" : 1,"" : 18] -> -non_foldable (x, y) => (return_val) -{ - tmp = this.foldable (x, y) - tmp_0 = this.foldable (x, y) - return_val = Add (tmp, tmp_0) -} - """ - ) - - _simple_function_folding.inline_simple_functions(model) - model = _remove_unused_functions(model) - - self.assertEqual(len(model.functions), 1) - self.assertEqual(model.functions[0].node[0].op_type, "Concat") - self.assertEqual(model.functions[0].node[1].op_type, "Concat") - - def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self): - 
model = onnx.parser.parse_model( - """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x) => ( a, b, c) { - a = this.prim_cast (x) - b = this.prim_cast (x) - c = this.prim_cast (x) -} -< - domain: "this", - opset_import: ["" : 18] -> -prim_cast (x) => (return_val) -{ - return_val = Cast (x) -} - """ - ) - _simple_function_folding.inline_simple_functions(model) - model = _remove_unused_functions(model) - self.assertEqual(len(model.functions), 0) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[0].attribute[0].i, 10) - self.assertEqual(model.graph.node[1].attribute[0].i, 6) - self.assertEqual(model.graph.node[2].attribute[0].i, 7) - - def test_fold_nested_if_function_succeeds(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 9, - opset_import: ["this" : 1, "" : 21] -> -func (float[1,512] x, float[1,512] y) => ( out) { - out = this.foldable_func (x, y) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable_func (x, y) => (z_6) -{ - cond = Constant () - z_6 = If (cond) ( z_2) { - cond_0 = Not (cond) - z_2 = If (cond_0) ( z) { - z = Add (x, x) - }, else_branch: graph = elseGraph_5 () => ( z_1) { - z_1 = Identity (x) - }> - }, else_branch: graph = elseGraph_4 () => ( z_5) { - z_5 = If (cond) ( z_3) { - z_3 = Add (y, y) - }, else_branch: graph = elseGraph_10 () => ( z_4) { - z_4 = Add (x, y) - }> - }> -} - """ - ) - - _simple_function_folding.inline_simple_functions(model) - model = _remove_unused_functions(model) - - self.assertEqual(len(model.functions), 0) - self.assertEqual(len(model.graph.node), 2) - self.assertEqual(model.graph.node[1].op_type, "If") - - def test_fold_function_with_unused_output(self): - model = onnx.parser.parse_model( - """ -< - ir_version: 8, - opset_import: ["this" : 1, "" : 18] -> -func ( x, y, z) => ( return_val) { - tmp = this.non_foldable (x, y) - return_val = Add (tmp, z) -} -< - domain: "this", - opset_import: ["" : 18] -> -foldable (x, y) => (return_val, unused, unused1) -{ - return_val = Concat (x, y) - unused = Identity (x) - unused1 = Identity (y) -} -< - domain: "this", - opset_import: ["this" : 1,"" : 18] -> -non_foldable (x, y) => (return_val) -{ - tmp, unused, unused1 = this.foldable (x, y) - tmp_0, unused2, unused3 = this.foldable (x, y) - return_val = Add (tmp, tmp_0) -} - """ - ) - - _simple_function_folding.inline_functions_with_unused_outputs(model) - model = _remove_unused_functions(model) - self.assertEqual(len(model.functions), 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/optimizer/_legacy/constant_folding.py b/onnxscript/optimizer/_legacy/constant_folding.py deleted file mode 100644 index d30a8c9cc8..0000000000 --- a/onnxscript/optimizer/_legacy/constant_folding.py +++ /dev/null @@ -1,293 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from __future__ import annotations - -import logging -from typing import Any, Sequence - -import numpy as np -import onnx -import onnx.reference.ops - -import onnxscript._legacy_ir as ir -import onnxscript.optimizer._constant_folding as _constant_folding -from onnxscript._legacy_ir import visitor -from onnxscript.optimizer._legacy import evaluator -from onnxscript.utils.utils import ( - is_control_flow_op, - is_onnx_domain, -) - -logger = logging.getLogger(__name__) - -# Ops excluded from constant-propagation: -# * Random ops, which are not deterministic (checked below) -# * Control flow ops (checked by presence of graph-attribute) - -onnx_domain = frozenset({"", "onnx.ai"}) - - -def is_non_deterministic_op(node: onnx.NodeProto) -> bool: - non_deterministic_ops = _constant_folding.non_deterministic_ops - return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain) - - -def is_constant_op(node: onnx.NodeProto) -> bool: - return node.op_type in {"Constant", "ConstantOfShape"} and is_onnx_domain(node.domain) - - -class ConstantFolder(visitor.FunctionCallsiteProtoTransformer): - def __init__( - self, - registry: evaluator.PartialEvaluatorRegistry, - external_data_folder: str, - *, - do_shape_inference: bool, - ) -> None: - self.registry = registry - # TODO: make evaluator a parameter - self.evaluate = evaluator.reference_evaluator.evaluate - self._do_shape_inference = do_shape_inference - self._init() - super().__init__(external_data_folder, do_shape_inference=do_shape_inference) - - def _init(self) -> None: - self.counts = {} - self.sizes = {} - - def add_count(self, op: str, size: int = 1): - self.counts[op] = self.counts.get(op, 0) + 1 - self.sizes[op] = self.sizes.get(op, 0) + size - - def foldable_value(self, name: str, value): - """Checks if a runtime-constant can and should be folded into the graph. - - We fold constants only if they are tensors (not lists of tensors, for example) - and have size below desired limit. - """ - if value is ir.NotConstant: - return None - - if not isinstance(value, np.ndarray): - # ONNX does not have a way to represent non-tensor constants, eg. a sequence. - # So, a constant-value of type sequence is not folded, but it can be used - # to optimize subsequent operations when possible. - logger.info( - "Skip storing constant folded value %s due to unsupported type %s.", - name, - type(value), - ) - return None - - if value.nbytes > _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT: - logger.info( - "Skip storing constant folded nvalue %s due to large size %s.", - name, - value.nbytes, - ) - return None - - return onnx.numpy_helper.from_array(value, name) - - def new_constant(self, name, value): - if isinstance(value, (int, float, np.ScalarType)): - value = np.array(value) - - info = self.lookup_or_create(name) - info.value = value - - tensor = self.foldable_value(name, value) - if tensor is None: - return None - - logger.debug( - "New constant for value %s dtype: %s shape: %s", - name, - value.dtype, - value.shape, - ) - info.type = onnx.helper.make_tensor_type_proto( - onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape - ) - node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor) - return [node] - - def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: - if self.scopes.current_scope().current_function_scope(): - # Need to resolve ref_attr_name if inside a function. 
- attr_dict = {} - for attribute in attributes: - concrete_attribute = ( - self.lookup_ref_attribute(attribute.ref_attr_name) - if attribute.ref_attr_name - else attribute - ) - if concrete_attribute is None: - continue - attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute) - return attr_dict - return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} - - def replace_copy(self, node: onnx.NodeProto) -> None: - for i in range(len(node.input)): - input = self.get_input(node, i) - if input is not None and input.is_copy(): - old_value = self.lookup_or_create(input.name) - assert isinstance(input.symbolic_value, str) - new_value = self.lookup_or_create(input.symbolic_value) - # Merge meta info. It is important to do if the new value - # is created by evaluator, and thus carries zero meta info. - # Since this is a copy, the meta info should be the same. - new_value.identity_merge_from(old_value) - node.input[i] = input.symbolic_value - - def process_function_outputs(self, function: onnx.FunctionProto) -> bool: - # Resolve copy for function subgraph output. - # Avoid copy of function subgraph input, because it is illegal for a direct edge - # from function input to function output. - prohibited_value_set = set(function.input) - updated = False - for i, output_name in enumerate(function.output): - output = self.lookup(output_name) - if ( - output is not None - and output.is_copy() - and output.symbolic_value not in prohibited_value_set - ): - old_value = self.lookup_or_create(output.name) - assert isinstance(output.symbolic_value, str) - new_value = self.lookup_or_create(output.symbolic_value) - new_value.identity_merge_from(old_value) - function.output[i] = output.symbolic_value - updated = True - return updated - - def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - self.replace_copy(node) - - super().process_node(node) - - inputs = [self.lookup(x) for x in node.input] - attrs = self.convert_attributes(node.attribute) - - domain = node.domain - op = node.op_type - version = self.lookup_version(domain) - - # if any(x is Undefined for x in inputs): - # return None - # Above check ensures that none of the optimizations below need to handle - # undefined inputs - - op_optimizers = self.registry.lookup_evaluators(domain, op, version) - for optimizer in op_optimizers: - assert optimizer - output = optimizer(self, node) - if output is None: - continue - if isinstance(output, list): - return output - else: - # Currently handles single output only - self.add_count(node.op_type, output.size) - return self.new_constant(node.output[0], output) - - if is_control_flow_op(node) or is_non_deterministic_op(node): - return None - - input_values = [x.value if x is not None else None for x in inputs] - if any(x is ir.NotConstant for x in input_values): - return None - - input_types = [x.type for x in inputs if x is not None] - - def is_excluded_type(type_proto: onnx.TypeProto | None) -> bool: - if type_proto is None: - return True - if type_proto.HasField("tensor_type"): - return type_proto.tensor_type.elem_type in { - onnx.TensorProto.BFLOAT16, - onnx.TensorProto.FLOAT8E4M3FN, - onnx.TensorProto.FLOAT8E4M3FNUZ, - onnx.TensorProto.FLOAT8E5M2, - onnx.TensorProto.FLOAT8E5M2FNUZ, - } - return False - - if any(is_excluded_type(x) for x in input_types): - return None - - outputs = self.evaluate(domain, op, version, *input_values, **attrs) - # TODO: what if evaluated value is None? 
- if outputs is None: - return None - if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): - replacement = self.new_constant(node.output[0], outputs) - if is_constant_op(node): - return None - self.add_count(op, outputs.size) - return replacement - else: - logger.warning("Skipping constant folding for op %s with multiple outputs.", op) - return None - - def process_function_node( - self, node: onnx.NodeProto - ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: - self.replace_copy(node) - - _, new_function = super().process_function_node(node) - - # Replace function node with Constant if all outputs are constants - ir_values = [self.lookup(output_name) for output_name in node.output] - tensors = [ - self.foldable_value(output_name, ir_value.value if ir_value is not None else None) - for output_name, ir_value in zip(node.output, ir_values) - ] - if all(tensor is not None for tensor in tensors): - replacements = [] - for output_name, tensor in zip(node.output, tensors): - newnode = onnx.helper.make_node( - "Constant", inputs=[], outputs=[output_name], value=tensor - ) - replacements.append(newnode) - logger.debug( - "Function node replacements: node %s %s (%s/%s)", - node.name, - [replacement.output for replacement in replacements], - len(replacements), - len(node.output), - ) - return replacements, new_function - return None, new_function - - def visit_model(self, model: onnx.ModelProto) -> None: - self._init() - - super().visit_model(model) - - -def fold_constants( - model: onnx.ModelProto, - external_data_folder: str = "", - *, - onnx_shape_inference: bool = False, -) -> bool: - """ - Applies constant folding optimization to the model. - Returns true iff the model was modified. - """ - folder = ConstantFolder( - evaluator.registry, - external_data_folder, - do_shape_inference=onnx_shape_inference, - ) - folder.visit_model(model) - for op in folder.counts: - logger.info( - "Constant-folded '%s' %s times, with %s size.", - op, - folder.counts[op], - folder.sizes[op], - ) - return folder.modified diff --git a/onnxscript/optimizer/_legacy/evaluator.py b/onnxscript/optimizer/_legacy/evaluator.py deleted file mode 100644 index 2b638eab30..0000000000 --- a/onnxscript/optimizer/_legacy/evaluator.py +++ /dev/null @@ -1,439 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# ------------------------------------------------------------------------- - -from __future__ import annotations - -import dataclasses -import logging -import math -from typing import Any, Callable, Protocol, Sequence, Union - -import numpy as np -import onnx -import onnx.reference.ops - -import onnxscript._legacy_ir as ir -from onnxscript.utils.utils import ( - get_node_attr_value, -) - -logger = logging.getLogger(__name__) - -# "Standard" evaluators are used to perform constant-folding. -# The API below works only for non-control-flow ops (ops without any graph-attributes). -# This currently used ONNX's reference implementation. But we could also -# use ORT's implementation if we want to. 
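For reference, the evaluation pattern the wrapper class below was built on can be exercised directly; a minimal sketch, assuming onnx's reference ops expose the `eval` classmethod relied on here (the op "Add", opset 14, and the inputs are illustrative):

    import numpy as np
    import onnx.reference.ops

    # Load the reference implementation of one op at a given opset version and
    # evaluate it eagerly on numpy inputs, as ReferenceEvaluator.evaluate did.
    add_impl = onnx.reference.ops.load_op("", "Add", 14)
    print(add_impl.eval(np.array([1.0]), np.array([2.0])))  # [3.]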
- - -class ReferenceEvaluator: - def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: - try: - op_impl_class = onnx.reference.ops.load_op(domain, op, version) - return op_impl_class.eval # noqa: TRY300 - except Exception: - return None - - def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: - logger.debug("Evaluating %s::%s", domain, op) - evaluator = self.get_evaluator(domain, op, version) - if evaluator is None: - return None - return evaluator(*args, **kwargs) - - -reference_evaluator = ReferenceEvaluator() - -# The "partial evaluators" below are non-standard evaluators. They are used to perform -# partial evaluation and/or static program analysis (abstract interpretation). - - -class IRContext(Protocol): - """A class that represents the context for partial evaluation. - - This is a placeholder, subject to simplification when a proper IR is defined. - """ - - def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - - def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... - - def input_const_value(self, node: onnx.NodeProto, index: int) -> ir.ConcreteValue: ... - - def input_shape( - self, node: onnx.NodeProto, index: int - ) -> onnx.TensorShapeProto | None: ... - - def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... - - def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... - - def lookup_version(self, domain: str) -> int: ... - - def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... - - def new_constant(self, name: str, value: Any) -> Sequence[onnx.NodeProto] | None: ... - - -# A partial-evaluator function takes an IRContext and a node, and returns a list of -# replacement nodes or None (if no replacement is needed). We return None instead -# of [input node] so the caller is aware that the node is not replaced. If the node -# is replaced, the caller will recursively visit the replacement nodes to process them. - -PartialEvaluatorFunction = Union[ - Callable[[IRContext, onnx.NodeProto], Sequence[onnx.NodeProto]], None -] - - -@dataclasses.dataclass -class PartialEvaluator: - """A class that represents a partial-evaluator for a particular op. - - It is applicable for a specific version range (min_version, max_version) of the op. - The min_version and max_version can be None, indicating that there is no version - constraint in that direction. 
- """ - - min_version: int | None - max_version: int | None - function: PartialEvaluatorFunction - - def valid_for(self, version: int) -> bool: - """Returns True if this evaluator is applicable for the given version.""" - return (self.min_version is None or version >= self.min_version) and ( - self.max_version is None or version <= self.max_version - ) - - -class PartialEvaluatorRegistry: - """A class that maintains a registry of evaluators for ops.""" - - def __init__(self): - self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} - - def lookup_evaluators(self, domain: str, opname: str, version: int): - evaluator_list = self.op_evaluators.get((domain, opname), []) - return [ - evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) - ] - - def register(self, opname: str, domain: str = "", version=None): - if (domain, opname) not in self.op_evaluators: - evaluator_list = [] - self.op_evaluators[(domain, opname)] = evaluator_list - else: - evaluator_list = self.op_evaluators[(domain, opname)] - if version is None: - min_version = None - max_version = None - elif isinstance(version, int): - min_version = version - max_version = version - elif isinstance(version, tuple): - min_version, max_version = version - - def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: - evaluator_list.append(PartialEvaluator(min_version, max_version, function)) - return function - - return decorator - - -registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() - -register = registry.register - - -def get_bool_value(val) -> bool | None: - if isinstance(val, bool): - return val - if isinstance(val, np.bool_): - return bool(val) - if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: - return val.item(0) - return None - - -def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): - size = 1 - for d in type.tensor_type.shape.dim: - size *= d.dim_value - return np.array(size, dtype=np.int64) - return None - - -def get_dim_info(type: onnx.TypeProto, dim: int) -> int | None: - if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): - rank = len(type.tensor_type.shape.dim) - dim = dim if dim >= 0 else dim + rank - if dim < 0 or dim >= rank: - return None - if type.tensor_type.shape.dim[dim].HasField("dim_value"): - return type.tensor_type.shape.dim[dim].dim_value - return None - - -@register("Cast") -def cast(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: - if context.input_shape(node, 0) is not None: - output_value = context.get_output(node, 0) - output_value.type = onnx.TypeProto() - output_value.type.CopyFrom(context.input_type(node, 0)) - output_value.type.tensor_type.elem_type = node.attribute[0].i - return None - - -@register("CastLike") -def cast_like(context: IRContext, node: onnx.NodeProto): - source_element_type = context.input_element_type(node, 0) - target_element_type = context.input_element_type(node, 1) - - if target_element_type is None: - return None - if source_element_type == target_element_type: - node.op_type = "Identity" - del node.input[1] - return [node] - - node.op_type = "Cast" - del node.input[1] - del node.attribute[:] - node.attribute.append(onnx.helper.make_attribute("to", target_element_type)) - return [node] - - -@register("Shape") -def shape(context: IRContext, node: onnx.NodeProto): - shape = 
-
-
-def get_bool_value(val) -> bool | None:
- if isinstance(val, bool):
- return val
- if isinstance(val, np.bool_):
- return bool(val)
- if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool:
- return val.item(0)
- return None
-
-
-def get_size_info(type: onnx.TypeProto) -> np.ndarray | None:
- if type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
- if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim):
- size = 1
- for d in type.tensor_type.shape.dim:
- size *= d.dim_value
- return np.array(size, dtype=np.int64)
- return None
-
-
-def get_dim_info(type: onnx.TypeProto, dim: int) -> int | None:
- if type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
- rank = len(type.tensor_type.shape.dim)
- dim = dim if dim >= 0 else dim + rank
- if dim < 0 or dim >= rank:
- return None
- if type.tensor_type.shape.dim[dim].HasField("dim_value"):
- return type.tensor_type.shape.dim[dim].dim_value
- return None
-
-
-@register("Cast")
-def cast(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
- if context.input_shape(node, 0) is not None:
- output_value = context.get_output(node, 0)
- output_value.type = onnx.TypeProto()
- output_value.type.CopyFrom(context.input_type(node, 0))
- output_value.type.tensor_type.elem_type = node.attribute[0].i
- return None
-
-
-@register("CastLike")
-def cast_like(context: IRContext, node: onnx.NodeProto):
- source_element_type = context.input_element_type(node, 0)
- target_element_type = context.input_element_type(node, 1)
-
- if target_element_type is None:
- return None
- if source_element_type == target_element_type:
- node.op_type = "Identity"
- del node.input[1]
- return [node]
-
- node.op_type = "Cast"
- del node.input[1]
- del node.attribute[:]
- node.attribute.append(onnx.helper.make_attribute("to", target_element_type))
- return [node]
-
-
-@register("Shape")
-def shape(context: IRContext, node: onnx.NodeProto):
- shape = context.input_shape(node, 0)
- if shape is None:
- return None
- start = get_node_attr_value(node, "start", 0)
- end = get_node_attr_value(node, "end", None)
- shape_slice = shape.dim[start:end]
- if all(d.HasField("dim_value") for d in shape_slice):
- return np.array([d.dim_value for d in shape_slice], dtype=np.int64)
- return None
-
-
-@register("Size")
-def size(context: IRContext, node: onnx.NodeProto):
- type = context.input_type(node, 0)
- size = get_size_info(type) if type is not None else None
- return size
-
-
-@register("If")
-def if_op(context: IRContext, node: onnx.NodeProto):
- cond = context.input_const_value(node, 0)
- if cond is ir.NotConstant:
- # Visitor will recursively visit subgraphs to constant-fold them.
- return None
- cond = get_bool_value(cond)
- if cond is not None:
- # cond is a constant-value: inline the branch
- branch = "then_branch" if cond else "else_branch"
- graph = onnx.helper.get_node_attr_value(node, branch)
-
- formal_outs = list(graph.output)
- actual_outs = node.output
- renamings = {
- formal.name: actual
- for formal, actual in zip(formal_outs, actual_outs)
- if actual != ""
- }
- # TODO: Extend renaming to intermediate values.
-
- def rename(name):
- return renamings.get(name, name)
-
- for sub_node in graph.node:
- # TODO: handle renaming inside subgraphs in nodes
- sub_node.input[:] = [rename(name) for name in sub_node.input]
- sub_node.output[:] = [rename(name) for name in sub_node.output]
- # Avoid name collision.
- sub_node.name = f"{node.name}_{sub_node.name}"
-
- # TODO: we should handle initializers as well!
- return list(graph.node)
- return None
-
-
-@register("Identity")
-def identity(context: IRContext, node: onnx.NodeProto):
- input = context.get_input(node, 0)
- output = context.get_output(node, 0)
- if input is not None and output is not None:
- output.symbolic_value = input.name
-
-
-@register("SequenceConstruct")
-def sequence_construct(
- context: IRContext, node: onnx.NodeProto
-) -> Sequence[onnx.NodeProto] | None:
- output = context.get_output(node, 0)
- if output is not None:
- output.symbolic_value = list(node.input)
- return None
-
-
-@register("ConcatFromSequence")
-def concat_from_sequence(
- context: IRContext, node: onnx.NodeProto
-) -> Sequence[onnx.NodeProto] | None:
- input = context.get_input(node, 0)
- attrs = context.convert_attributes(node.attribute)
- new_axis = attrs.get("new_axis", 0)
- if input is not None and isinstance(input.symbolic_value, list):
- if new_axis == 0:
- node.op_type = "Concat"
- node.input[:] = input.symbolic_value
- logger.debug("ConcatFromSequence => Concat: %s", node.input)
- for i in range(len(node.attribute)):
- if node.attribute[i].name == "new_axis":
- del node.attribute[i]
- return [node]
- return [node]
- if new_axis == 1:
- # Unsqueeze the inputs with concat axis if new_axis is 1
- axis = attrs.get("axis", None)
- assert axis is not None
- output = context.get_output(node, 0)
- axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0]
- unsqueeze_nodes = []
- for node_input in input.symbolic_value:
- unsqueeze_node = onnx.helper.make_node(
- "Unsqueeze",
- [node_input, axis_node.output[0]],
- [f"{node_input}_unsqueeze"],
- )
- unsqueeze_nodes.append(unsqueeze_node)
- unsqueeze_outputs = [n.output[0] for n in unsqueeze_nodes]
- unsqueeze_nodes = [axis_node, *unsqueeze_nodes]
-
- # Send unsqueezed outputs to Concat
- node.input[:] = unsqueeze_outputs
- node.op_type = "Concat"
- logger.debug(
- "ConcatFromSequence => UnSqueeze %s + Concat %s",
- unsqueeze_outputs,
- node.input,
- )
- for i in range(len(node.attribute)):
- if node.attribute[i].name == "new_axis":
- del node.attribute[i]
- break
- return [*unsqueeze_nodes, node]
- return None
-
-
-@register("SplitToSequence")
-def split_to_sequence(
- context: IRContext, node: onnx.NodeProto
-) -> Sequence[onnx.NodeProto] | None:
- """Rewriting pattern.
-
- From
-
- splits = onnx::SplitToSequence(input, split, axis=axis)
-
- to
-
- split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis)
- splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)
-
- or
-
- split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1)
- splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n)
-
- where number of output tensors in `splits` is statically known.
- onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator.
- This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly.
- """
- input = context.get_input(node, 0)
- split = context.get_input(node, 1)
- attrs = context.convert_attributes(node.attribute)
- output = context.get_output(node, 0)
-
- if input is None or split is None or output is None:
- return None
-
- axis = attrs.get("axis", 0)
- if input.type is None:
- return None
- split_dimension_size = get_dim_info(input.type, axis)
- if split_dimension_size is None:
- return None
-
- split_value = split.value
- if split_value is None or split_value is ir.NotConstant:
- return None
- assert isinstance(split_value, np.ndarray)
-
- if split_value.ndim == 0:
- # split into chunks all of size 'split' if possible.
- num_outputs = math.ceil(split_dimension_size / split_value.item())
- split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
- split_node = onnx.helper.make_node(
- "Split",
- [input.name],
- split_outputs,
- axis=axis,
- num_outputs=num_outputs,
- )
- else:
- # split into 'size(split)' chunks
- num_outputs = split_value.size
- split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)]
- split_node = onnx.helper.make_node(
- "Split",
- [input.name, split.name],
- split_outputs,
- axis=axis,
- )
-
- keepdims = attrs.get("keepdims", 1)
- squeeze_nodes = []
- if keepdims == 0:
- # squeeze the split dimension if keepdims is 0
- axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0]
- for i in range(num_outputs):
- squeeze_node = onnx.helper.make_node(
- "Squeeze",
- [split_outputs[i], axis_node.output[0]],
- [f"{split_outputs[i]}_squeeze"],
- )
- squeeze_nodes.append(squeeze_node)
- split_outputs = [n.output[0] for n in squeeze_nodes]
- squeeze_nodes = [axis_node, *squeeze_nodes]
-
- node.op_type = "SequenceConstruct"
- node.input[:] = split_outputs
- del node.attribute[:]
- logger.debug(
- "SplitToSequence => Split %s + SequenceConstruct %s",
- split_node.input,
- node.input,
- )
- return [split_node, *squeeze_nodes, node]
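A worked instance of the scalar-split branch above (the sizes are illustrative, not from the source): if the split axis has static size 10 and `split` is a 0-d tensor holding 3, the rewrite emits a Split with a statically computed output count.

    import math

    # ceil(10 / 3) == 4: Split produces four outputs, and the node itself is
    # rewritten into SequenceConstruct over those four Split outputs.
    num_outputs = math.ceil(10 / 3)
    assert num_outputs == 4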
-
-
-@register("SequenceAt")
-def sequence_at(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
- input = context.get_input(node, 0)
- position = context.get_input(node, 1)
- output = context.get_output(node, 0)
- if input is not None and position is not None:
- input_vals = input.symbolic_value
- position_val = position.value
- if isinstance(input_vals, list) and position_val is not None:
- output.symbolic_value = input_vals[position_val]
- logger.debug("SequenceAt %s => %s", input, output.symbolic_value)
- new_node = onnx.helper.make_node(
- "Identity", [output.symbolic_value], [output.name]
- )
- return [new_node]
- return None
diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py
index 562cdc9690..3aaba1b057 100644
--- a/onnxscript/optimizer/_optimizer.py
+++ b/onnxscript/optimizer/_optimizer.py
@@ -21,6 +21,7 @@ def optimize_ir(
 stop_if_no_change: bool = True,
 input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT,
 output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT,
+ inline: bool = True,
 ) -> None:
 """Optimizes a model.
@@ -32,11 +33,10 @@ def optimize_ir(
 greater than this. Does not apply to special ops like Shape() and Size().
 output_size_limit: Will not rewrite any foldable-op into a Constant op if the size
 of the output tensor is greater than this.
- stop_if_no_change: Not supported currently (has no effect). Meant to stop the
- outer optimization loop if no change is detected in one iteration.
+ stop_if_no_change: Stop the optimization loop if no change is detected in an iteration.
+ inline: If True, inlines all functions in the model.
 """
- optimizer_pass = ir.passes.Sequential(
- onnxscript.ir.passes.common.inliner.InlinePass(),
+ passes = [
 ir.passes.PassManager(
 [
 _constant_folding.FoldConstantsPass(
@@ -54,7 +54,11 @@ def optimize_ir(
 ),
 onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
 onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass(),
- )
+ ]
+ if inline:
+ # Inline all functions first before optimizing
+ passes = [onnxscript.ir.passes.common.inliner.InlinePass(), *passes]
+ optimizer_pass = ir.passes.Sequential(*passes)
 assert optimizer_pass.in_place
 result = optimizer_pass(model)
 assert result.model is model
diff --git a/onnxscript/optimizer/_remove_unused_function.py b/onnxscript/optimizer/_remove_unused_function.py
deleted file mode 100644
index 8d960d983f..0000000000
--- a/onnxscript/optimizer/_remove_unused_function.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-from __future__ import annotations - -import logging -from typing import TypeVar - -import onnx - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -TModel = TypeVar("TModel", ir.Model, onnx.ModelProto) diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py index 501004bc95..f2b5f9ff8f 100644 --- a/onnxscript/tools/transformers_models/phi_test.py +++ b/onnxscript/tools/transformers_models/phi_test.py @@ -10,8 +10,6 @@ import onnxruntime import torch -import onnxscript.optimizer -import onnxscript.rewriter import onnxscript.tools.training_helper import onnxscript.tools.transformers_models import onnxscript.tools.transformers_models.phi @@ -83,6 +81,9 @@ def test_phi_export_cuda(self): @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") @unittest.skipIf(not has_transformers(), reason="transformers is missing") + @unittest.skipIf( + not hasattr(onnxruntime, "training"), reason="ORT training removed since 1.22" + ) @ignore_warnings(UserWarning) def test_phi_dort_static(self): model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() diff --git a/tests/optimizer/test_models.py b/tests/optimizer/test_models.py index 679898ed04..ed3a68bce1 100644 --- a/tests/optimizer/test_models.py +++ b/tests/optimizer/test_models.py @@ -16,7 +16,10 @@ from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils -_SKIP_TABLE = {} +_SKIP_TABLE = { + "resnet18": "fixme: ORT aborts when loading the model - https://github.com/microsoft/onnxruntime/issues/24473", + "mobilenetv2_100": "fixme: ORT aborts when loading the model - https://github.com/microsoft/onnxruntime/issues/24473", +} model_folder_path = ( pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" / "e2e_models"
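A usage sketch for the updated optimize_ir entry point changed above ("model.onnx" is a placeholder path; optimize_ir mutates the ir.Model in place):

    import onnx

    from onnxscript import ir, optimizer

    model = ir.serde.deserialize_model(onnx.load("model.onnx"))
    # inline=False preserves model-local functions; the default (True) runs
    # the inliner pass before the optimization loop.
    optimizer.optimize_ir(model, num_iterations=2, inline=False)
    onnx.save(ir.serde.serialize_model(model), "model_optimized.onnx")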