diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 112bf2be45..de4446bd62 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -93,10 +93,27 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph class RemoveUnusedNodesPass(ir.passes.InPlacePass): + """Pass for removing unused nodes and initializers. + + Attributes: + remove_initialized_inputs: When an unused initializer is simultaneously a graph input, + remove that input as well. Note that this will change the model input signature. + """ + + def __init__(self, remove_initialized_inputs: bool = False): + super().__init__() + self.remove_initialized_inputs = remove_initialized_inputs + def call(self, model: ir.Model) -> ir.passes.PassResult: count = _remove_unused_nodes_in_graph_like(model.graph) graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers + if self.remove_initialized_inputs: + graph_inputs = model.graph.inputs + for i, inp in reversed(list(enumerate(graph_inputs))): + if inp.name in initializers and not (inp in graph_outputs or inp.uses()): + del graph_inputs[i] + count += 1 for init in list(initializers.values()): if not (init in graph_outputs or init.uses()): assert init.name is not None diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 664b36577c..d0a27626ed 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -13,13 +13,15 @@ class RemoveUnusedTest(unittest.TestCase): using_ir: bool - def remove_unused_nodes(self, model: onnx.ModelProto): + def remove_unused_nodes( + self, model: onnx.ModelProto, remove_initialized_inputs: bool = False + ): if self.using_ir: model_ir = ir.serde.deserialize_model(model) - onnxscript.optimizer.remove_unused_nodes(model_ir) + onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs) model = ir.serde.serialize_model(model_ir) return model - onnxscript.optimizer.remove_unused_nodes(model) + onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs) return model def test_remove_unused_nodes(self): @@ -54,6 +56,59 @@ def test_remove_unused_initializers(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.initializer), 0) + def test_unused_initialized_inputs_are_removed_when_requested(self): + # https://github.com/microsoft/onnxscript/issues/2211 + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + model = self.remove_unused_nodes(model, remove_initialized_inputs=True) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.input), 1) + + def test_unused_initialized_inputs_are_kept_by_default(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + model = self.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.input), 2) + + @parameterized.parameterized.expand([True, False]) + def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool): + # preserve inputs as part of interface + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + model = self.remove_unused_nodes( + model, remove_initialized_inputs=remove_initialized_inputs + ) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.input), 2) + def test_partially_used_nodes(self): model = onnx.parser.parse_model( """ diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index a6e8ea2fc5..7cb0653a05 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -112,15 +112,19 @@ def fold_constants( return result -def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: +def remove_unused_nodes( + model: ir.Model | onnx.ModelProto, remove_initialized_inputs: bool = False +) -> None: """Removes unused nodes from a model inplace.""" if isinstance(model, ir.Model): - onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model) + onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass( + remove_initialized_inputs=remove_initialized_inputs + )(model) else: model_ir = ir.serde.deserialize_model(model) - model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()( - model_ir - ).model + model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass( + remove_initialized_inputs=remove_initialized_inputs + )(model_ir).model new_proto = ir.serde.serialize_model(model_ir) model.Clear() model.CopyFrom(new_proto)