From 7a720bb0fa82232898e52f08dcd822268305cb1a Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Thu, 17 Apr 2025 13:09:23 +0300 Subject: [PATCH 01/28] Update unused_removal.py https://github.com/microsoft/onnxscript/issues/2211 --- onnxscript/ir/passes/common/unused_removal.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 112bf2be45..8772544fb1 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -102,6 +102,11 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: assert init.name is not None del initializers[init.name] count += 1 + graph_inputs = model.graph.inputs + for num, input in list( enumerate( graph_inputs ) )[::-1]: + if not (input in graph_outputs or input.uses()): + del graph_inputs[num] + count += 1 for function in model.functions.values(): count += _remove_unused_nodes_in_graph_like(function) if count: From 62458de0193739b9b680b75d6a3370408e24b2bf Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Thu, 17 Apr 2025 16:44:15 +0300 Subject: [PATCH 02/28] Update unused_removal.py Let's affect only inputs that are initializers --- onnxscript/ir/passes/common/unused_removal.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 8772544fb1..ca4944761c 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -97,16 +97,16 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: count = _remove_unused_nodes_in_graph_like(model.graph) graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers + graph_inputs = model.graph.inputs + for num, input in list( enumerate( graph_inputs ) )[::-1]: + if input.name in initializers and not (input in graph_outputs or input.uses()): + del graph_inputs[num] + count += 1 for init in list(initializers.values()): if not (init in graph_outputs or init.uses()): assert init.name is not None del initializers[init.name] count += 1 - graph_inputs = model.graph.inputs - for num, input in list( enumerate( graph_inputs ) )[::-1]: - if not (input in graph_outputs or input.uses()): - del graph_inputs[num] - count += 1 for function in model.functions.values(): count += _remove_unused_nodes_in_graph_like(function) if count: From dced0cf295f4f232242944ed585ca0b07d27859f Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Thu, 17 Apr 2025 18:33:18 +0300 Subject: [PATCH 03/28] Update unused_removal.py [::-1]->reversed --- onnxscript/ir/passes/common/unused_removal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index ca4944761c..29bc6543b2 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -98,7 +98,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers graph_inputs = model.graph.inputs - for num, input in list( enumerate( graph_inputs ) )[::-1]: + for num, input in reversed(list(enumerate(graph_inputs))): if input.name in initializers and not (input in graph_outputs or input.uses()): del graph_inputs[num] count += 1 From 51cac0d06dc4e3c70388bc175e099ad47f24ff8b Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Thu, 17 Apr 2025 18:35:28 +0300 Subject: [PATCH 04/28] Update unused_removal.py trailing spaces remove --- onnxscript/ir/passes/common/unused_removal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 29bc6543b2..4066b089e0 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -101,7 +101,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: for num, input in reversed(list(enumerate(graph_inputs))): if input.name in initializers and not (input in graph_outputs or input.uses()): del graph_inputs[num] - count += 1 + count += 1 for init in list(initializers.values()): if not (init in graph_outputs or init.uses()): assert init.name is not None From d20287bc4edee749029cf887530ea450c714f17c Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Thu, 17 Apr 2025 18:50:25 +0300 Subject: [PATCH 05/28] Update onnxscript/ir/passes/common/unused_removal.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 4066b089e0..0298c820d9 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -98,9 +98,9 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers graph_inputs = model.graph.inputs - for num, input in reversed(list(enumerate(graph_inputs))): + for i, input in reversed(list(enumerate(graph_inputs))): if input.name in initializers and not (input in graph_outputs or input.uses()): - del graph_inputs[num] + del graph_inputs[i] count += 1 for init in list(initializers.values()): if not (init in graph_outputs or init.uses()): From f6c3eb7846978b50c72f872f651f506876675e6a Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Fri, 18 Apr 2025 18:04:17 +0300 Subject: [PATCH 06/28] Update unused_removal.py RemoveUnusedNodesPass + parameter remove_initialized_inputs --- onnxscript/ir/passes/common/unused_removal.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 0298c820d9..4305c738ad 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -93,15 +93,26 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph class RemoveUnusedNodesPass(ir.passes.InPlacePass): + def __init__(self, remove_initialized_inputs: bool =True ): + """ + :param remove_initialized_inputs: if `True` (default) remove unused inputs, in case + where is corresponding initializer, (those are typically rather initializers than inputs) + if changed to `False`, unused inputs remain, even if it has default initializer + Note: usual inputs will remain anyhow + """ + super().__init__() + self.remove_initialized_inputs = remove_initialized_inputs + def call(self, model: ir.Model) -> ir.passes.PassResult: count = _remove_unused_nodes_in_graph_like(model.graph) graph_outputs = frozenset(model.graph.outputs) initializers = model.graph.initializers - graph_inputs = model.graph.inputs - for i, input in reversed(list(enumerate(graph_inputs))): - if input.name in initializers and not (input in graph_outputs or input.uses()): - del graph_inputs[i] - count += 1 + if self.remove_initialized_inputs: + graph_inputs = model.graph.inputs + for i, input in reversed(list(enumerate(graph_inputs))): + if input.name in initializers and not (input in graph_outputs or input.uses()): + del graph_inputs[i] + count += 1 for init in list(initializers.values()): if not (init in graph_outputs or init.uses()): assert init.name is not None From d0cafdfc197260676f2691e25fe35433e09b1e7a Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Fri, 18 Apr 2025 18:16:06 +0300 Subject: [PATCH 07/28] Update __init__.py remove_unused_noodles + parameter remove_initialized_inputs --- onnxscript/optimizer/__init__.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index b073b3345e..3afb3b462d 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -50,15 +50,19 @@ def fold_constants( return legacy_constant_folding.fold_constants(model, *args, **kwargs) -def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None: +def remove_unused_nodes(model: ir.Model | onnx.ModelProto, + remove_initialized_inputs: bool=True + ) -> None: """Removes unused nodes from a model inplace.""" if isinstance(model, ir.Model): - onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()(model) + onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass( + remove_initialized_inputs=remove_initialized_inputs + )(model) else: model_ir = ir.serde.deserialize_model(model) - model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass()( - model_ir - ).model + model_ir = onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass( + remove_initialized_inputs=remove_initialized_inputs + )(model_ir).model new_proto = ir.serde.serialize_model(model_ir) model.Clear() model.CopyFrom(new_proto) From e07b2488ab53b9436695c89c19584c1a43d84527 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Fri, 18 Apr 2025 18:34:54 +0300 Subject: [PATCH 08/28] Update unused_removal_test.py tests for issue #2211 --- .../ir/passes/common/unused_removal_test.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 664b36577c..1c3d6c60b0 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -54,6 +54,62 @@ def test_remove_unused_initializers(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.initializer), 0) +def test_remove_unused_inputs_initializers(): + # remove inputs in case they are initializers + # https://github.com/microsoft/onnxscript/issues/2211 + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + remove_unused_nodes(ir_model) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 1) + assert (ir_model.graph.node(0).op_type== "Mul") + +def test_avoid_remove_unused_inputs_initializers(): + # supress remove inputs in case they are initializers + # if explicitly said + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + remove_unused_nodes(ir_model,False) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 2) + assert (ir_model.graph.node(0).op_type== "Mul") + +def test_avoid_remove_unused_inputs(): + # preserve inputs as part of interface + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + remove_unused_nodes(ir_model) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 2) + assert (ir_model.graph.node(0).op_type== "Mul") + def test_partially_used_nodes(self): model = onnx.parser.parse_model( """ From 14c918e0aaa373cfb48598d09545efff0fb0ec28 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Mon, 21 Apr 2025 12:26:10 +0300 Subject: [PATCH 09/28] Update onnxscript/ir/passes/common/unused_removal.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 4305c738ad..b873410987 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -93,13 +93,13 @@ def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph class RemoveUnusedNodesPass(ir.passes.InPlacePass): + """Pass for removing unused nodes and initializers. + + Attributes: + remove_initialized_inputs: When an unused initializer is simultaneously a graph input, + remove that input as well. Note that this will change the model input signature. + """ def __init__(self, remove_initialized_inputs: bool =True ): - """ - :param remove_initialized_inputs: if `True` (default) remove unused inputs, in case - where is corresponding initializer, (those are typically rather initializers than inputs) - if changed to `False`, unused inputs remain, even if it has default initializer - Note: usual inputs will remain anyhow - """ super().__init__() self.remove_initialized_inputs = remove_initialized_inputs From 30ad8bd22db69ce14fbd8b97bb2fe52b6098d262 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Mon, 21 Apr 2025 12:30:09 +0300 Subject: [PATCH 10/28] Update unused_removal.py For API compatibility! --- onnxscript/ir/passes/common/unused_removal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index b873410987..73edad7454 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -99,7 +99,7 @@ class RemoveUnusedNodesPass(ir.passes.InPlacePass): remove_initialized_inputs: When an unused initializer is simultaneously a graph input, remove that input as well. Note that this will change the model input signature. """ - def __init__(self, remove_initialized_inputs: bool =True ): + def __init__(self, remove_initialized_inputs: bool =False ): super().__init__() self.remove_initialized_inputs = remove_initialized_inputs From b3f9cdea12e25e34f7ac218fb1ae3be86ff68345 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Mon, 21 Apr 2025 12:31:10 +0300 Subject: [PATCH 11/28] Update __init__.py For API compatibility! --- onnxscript/optimizer/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 3afb3b462d..2655626eb9 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -51,7 +51,7 @@ def fold_constants( def remove_unused_nodes(model: ir.Model | onnx.ModelProto, - remove_initialized_inputs: bool=True + remove_initialized_inputs: bool=False ) -> None: """Removes unused nodes from a model inplace.""" if isinstance(model, ir.Model): From b00036c8584cd69ec0ad6ff4737b2200dad05369 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Mon, 21 Apr 2025 12:56:09 +0300 Subject: [PATCH 12/28] Update unused_removal_test.py errors fix --- .../ir/passes/common/unused_removal_test.py | 110 +++++++++--------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 1c3d6c60b0..ee66da93c0 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -13,13 +13,13 @@ class RemoveUnusedTest(unittest.TestCase): using_ir: bool - def remove_unused_nodes(self, model: onnx.ModelProto): + def remove_unused_nodes(self, model: onnx.ModelProto, remove_initialized_inputs: bool=False): if self.using_ir: model_ir = ir.serde.deserialize_model(model) - onnxscript.optimizer.remove_unused_nodes(model_ir) + onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs) model = ir.serde.serialize_model(model_ir) return model - onnxscript.optimizer.remove_unused_nodes(model) + onnxscript.optimizer.remove_unused_nodes(model, remove_initialized_inputs) return model def test_remove_unused_nodes(self): @@ -54,61 +54,61 @@ def test_remove_unused_initializers(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.initializer), 0) -def test_remove_unused_inputs_initializers(): - # remove inputs in case they are initializers - # https://github.com/microsoft/onnxscript/issues/2211 - model = onnx.parser.parse_model( + def test_remove_unused_inputs_initializers(): + # remove inputs in case they are initializers + # if explicitly said + # https://github.com/microsoft/onnxscript/issues/2211 + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } """ - - agraph (float[N] x, float[N] two) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - remove_unused_nodes(ir_model) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 1) - assert (ir_model.graph.node(0).op_type== "Mul") - -def test_avoid_remove_unused_inputs_initializers(): - # supress remove inputs in case they are initializers - # if explicitly said - model = onnx.parser.parse_model( + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + ir_model = self.remove_unused_nodes(ir_model,True) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 1) + assert (ir_model.graph.node(0).op_type== "Mul") + + def test_avoid_remove_unused_inputs_initializers(): + # supress remove inputs in case they are initializers until explicitly said + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } """ - - agraph (float[N] x, float[N] two) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - remove_unused_nodes(ir_model,False) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 2) - assert (ir_model.graph.node(0).op_type== "Mul") - -def test_avoid_remove_unused_inputs(): - # preserve inputs as part of interface - model = onnx.parser.parse_model( + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + ir_model = self.remove_unused_nodes(ir_model) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 2) + assert (ir_model.graph.node(0).op_type== "Mul") + + def test_avoid_remove_unused_inputs(): + # preserve inputs as part of interface + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] two) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } """ - - agraph (float[N] x, float[N] two) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - remove_unused_nodes(ir_model) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 2) - assert (ir_model.graph.node(0).op_type== "Mul") + ) + ir_model = onnxscript.ir.serde.deserialize_model(model) + ir_model = self.remove_unused_nodes(ir_model,True) + assert (len(ir_model.graph._nodes)== 1) + assert (len(ir_model.graph.inputs)== 2) + assert (ir_model.graph.node(0).op_type== "Mul") def test_partially_used_nodes(self): model = onnx.parser.parse_model( From 817709e87c43239f3014ef964eff73016b1814bf Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 15:56:59 +0300 Subject: [PATCH 13/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index ee66da93c0..d16b4b6e86 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -54,7 +54,7 @@ def test_remove_unused_initializers(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.initializer), 0) - def test_remove_unused_inputs_initializers(): + def test_remove_unused_inputs_initializers(self): # remove inputs in case they are initializers # if explicitly said # https://github.com/microsoft/onnxscript/issues/2211 From 716750c645b9bd7ad84777f910ee2675f197e28f Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 16:13:33 +0300 Subject: [PATCH 14/28] Update unused_removal_test.py --- .../ir/passes/common/unused_removal_test.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index d16b4b6e86..ba6b7c9fb7 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -70,11 +70,11 @@ def test_remove_unused_inputs_initializers(self): ) ir_model = onnxscript.ir.serde.deserialize_model(model) ir_model = self.remove_unused_nodes(ir_model,True) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 1) - assert (ir_model.graph.node(0).op_type== "Mul") - - def test_avoid_remove_unused_inputs_initializers(): + self.assertEqual(len(ir_model.graph.node), 1) + self.assertEqual(ir_model.graph.node[0].op_type, "Mul") + self.assertEqual(len(ir_model.graph.inputs), 1) + + def test_avoid_remove_unused_inputs_initializers(self): # supress remove inputs in case they are initializers until explicitly said model = onnx.parser.parse_model( """ @@ -88,11 +88,11 @@ def test_avoid_remove_unused_inputs_initializers(): ) ir_model = onnxscript.ir.serde.deserialize_model(model) ir_model = self.remove_unused_nodes(ir_model) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 2) - assert (ir_model.graph.node(0).op_type== "Mul") - - def test_avoid_remove_unused_inputs(): + self.assertEqual(len(ir_model.graph.node), 1) + self.assertEqual(ir_model.graph.node[0].op_type, "Mul") + self.assertEqual(len(ir_model.graph.inputs), 2) + + def test_avoid_remove_unused_inputs(self): # preserve inputs as part of interface model = onnx.parser.parse_model( """ @@ -106,9 +106,9 @@ def test_avoid_remove_unused_inputs(): ) ir_model = onnxscript.ir.serde.deserialize_model(model) ir_model = self.remove_unused_nodes(ir_model,True) - assert (len(ir_model.graph._nodes)== 1) - assert (len(ir_model.graph.inputs)== 2) - assert (ir_model.graph.node(0).op_type== "Mul") + self.assertEqual(len(ir_model.graph.node), 1) + self.assertEqual(ir_model.graph.node[0].op_type, "Mul") + self.assertEqual(len(ir_model.graph.inputs), 2) def test_partially_used_nodes(self): model = onnx.parser.parse_model( From 03ddd34bf3b9f8056ff43c28da437e7a76212289 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 16:17:05 +0300 Subject: [PATCH 15/28] Update unused_removal.py one trailing white space removed --- onnxscript/ir/passes/common/unused_removal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 73edad7454..9d194512c3 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -97,7 +97,7 @@ class RemoveUnusedNodesPass(ir.passes.InPlacePass): Attributes: remove_initialized_inputs: When an unused initializer is simultaneously a graph input, - remove that input as well. Note that this will change the model input signature. + remove that input as well. Note that this will change the model input signature. """ def __init__(self, remove_initialized_inputs: bool =False ): super().__init__() From 4b138b090c9f11311481fe9f3e2a4143799fc21c Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 16:38:02 +0300 Subject: [PATCH 16/28] linter. hate it Meh --- onnxscript/ir/passes/common/unused_removal.py | 3 ++- onnxscript/ir/passes/common/unused_removal_test.py | 8 +++++--- onnxscript/optimizer/__init__.py | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 9d194512c3..28f1bec837 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -99,7 +99,8 @@ class RemoveUnusedNodesPass(ir.passes.InPlacePass): remove_initialized_inputs: When an unused initializer is simultaneously a graph input, remove that input as well. Note that this will change the model input signature. """ - def __init__(self, remove_initialized_inputs: bool =False ): + + def __init__(self, remove_initialized_inputs: bool = False): super().__init__() self.remove_initialized_inputs = remove_initialized_inputs diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index ba6b7c9fb7..560e0b5563 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -13,7 +13,9 @@ class RemoveUnusedTest(unittest.TestCase): using_ir: bool - def remove_unused_nodes(self, model: onnx.ModelProto, remove_initialized_inputs: bool=False): + def remove_unused_nodes( + self, model: onnx.ModelProto, remove_initialized_inputs: bool = False + ): if self.using_ir: model_ir = ir.serde.deserialize_model(model) onnxscript.optimizer.remove_unused_nodes(model_ir, remove_initialized_inputs) @@ -69,7 +71,7 @@ def test_remove_unused_inputs_initializers(self): """ ) ir_model = onnxscript.ir.serde.deserialize_model(model) - ir_model = self.remove_unused_nodes(ir_model,True) + ir_model = self.remove_unused_nodes(ir_model, True) self.assertEqual(len(ir_model.graph.node), 1) self.assertEqual(ir_model.graph.node[0].op_type, "Mul") self.assertEqual(len(ir_model.graph.inputs), 1) @@ -105,7 +107,7 @@ def test_avoid_remove_unused_inputs(self): """ ) ir_model = onnxscript.ir.serde.deserialize_model(model) - ir_model = self.remove_unused_nodes(ir_model,True) + ir_model = self.remove_unused_nodes(ir_model, True) self.assertEqual(len(ir_model.graph.node), 1) self.assertEqual(ir_model.graph.node[0].op_type, "Mul") self.assertEqual(len(ir_model.graph.inputs), 2) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 2655626eb9..5e24300f6d 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -50,9 +50,9 @@ def fold_constants( return legacy_constant_folding.fold_constants(model, *args, **kwargs) -def remove_unused_nodes(model: ir.Model | onnx.ModelProto, - remove_initialized_inputs: bool=False - ) -> None: +def remove_unused_nodes( + model: ir.Model | onnx.ModelProto, remove_initialized_inputs: bool = False +) -> None: """Removes unused nodes from a model inplace.""" if isinstance(model, ir.Model): onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass( From f5ff16cd3e1f84ab516f6d3b569c167a036ae530 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:03:24 +0300 Subject: [PATCH 17/28] Update unused_removal_test.py change model to proto --- .../ir/passes/common/unused_removal_test.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 560e0b5563..1447430fdb 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -70,11 +70,10 @@ def test_remove_unused_inputs_initializers(self): } """ ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - ir_model = self.remove_unused_nodes(ir_model, True) - self.assertEqual(len(ir_model.graph.node), 1) - self.assertEqual(ir_model.graph.node[0].op_type, "Mul") - self.assertEqual(len(ir_model.graph.inputs), 1) + model = self.remove_unused_nodes(model, True) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.inputs), 1) def test_avoid_remove_unused_inputs_initializers(self): # supress remove inputs in case they are initializers until explicitly said @@ -88,11 +87,10 @@ def test_avoid_remove_unused_inputs_initializers(self): } """ ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - ir_model = self.remove_unused_nodes(ir_model) - self.assertEqual(len(ir_model.graph.node), 1) - self.assertEqual(ir_model.graph.node[0].op_type, "Mul") - self.assertEqual(len(ir_model.graph.inputs), 2) + model = self.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.inputs), 2) def test_avoid_remove_unused_inputs(self): # preserve inputs as part of interface @@ -106,11 +104,10 @@ def test_avoid_remove_unused_inputs(self): } """ ) - ir_model = onnxscript.ir.serde.deserialize_model(model) - ir_model = self.remove_unused_nodes(ir_model, True) - self.assertEqual(len(ir_model.graph.node), 1) - self.assertEqual(ir_model.graph.node[0].op_type, "Mul") - self.assertEqual(len(ir_model.graph.inputs), 2) + model = self.remove_unused_nodes(model, True) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.inputs), 2) def test_partially_used_nodes(self): model = onnx.parser.parse_model( From 1b86d3755d4baa7013382b1779509c965dfc2891 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:31:11 +0300 Subject: [PATCH 18/28] Update unused_removal_test.py AAAA!!!! --- onnxscript/ir/passes/common/unused_removal_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 1447430fdb..de417dfce5 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -73,7 +73,7 @@ def test_remove_unused_inputs_initializers(self): model = self.remove_unused_nodes(model, True) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.inputs), 1) + self.assertEqual(len(model.graph.input), 1) def test_avoid_remove_unused_inputs_initializers(self): # supress remove inputs in case they are initializers until explicitly said @@ -90,7 +90,7 @@ def test_avoid_remove_unused_inputs_initializers(self): model = self.remove_unused_nodes(model) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.inputs), 2) + self.assertEqual(len(model.graph.input), 2) def test_avoid_remove_unused_inputs(self): # preserve inputs as part of interface @@ -107,7 +107,7 @@ def test_avoid_remove_unused_inputs(self): model = self.remove_unused_nodes(model, True) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.inputs), 2) + self.assertEqual(len(model.graph.input), 2) def test_partially_used_nodes(self): model = onnx.parser.parse_model( From a434f3a08a4b873c86a0bf984d4463cc2b0bcb88 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:34:05 +0300 Subject: [PATCH 19/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index de417dfce5..562910da60 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -70,7 +70,7 @@ def test_remove_unused_inputs_initializers(self): } """ ) - model = self.remove_unused_nodes(model, True) + model = self.remove_unused_nodes(model, remove_initialized_inputs=True) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 1) From e34de7206c197bf33e067756969cb65de063208e Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:35:41 +0300 Subject: [PATCH 20/28] Update unused_removal_test.py --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 562910da60..309c31a0f6 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -104,7 +104,7 @@ def test_avoid_remove_unused_inputs(self): } """ ) - model = self.remove_unused_nodes(model, True) + model = self.remove_unused_nodes(model, remove_initialized_inputs=True) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 2) From 39d697945c76e053ade611df6a368c8638937607 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:36:57 +0300 Subject: [PATCH 21/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 309c31a0f6..a0226fa5fa 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -56,7 +56,7 @@ def test_remove_unused_initializers(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.initializer), 0) - def test_remove_unused_inputs_initializers(self): + def test_unused_initialized_inputs_are_removed_when_requested(self): # remove inputs in case they are initializers # if explicitly said # https://github.com/microsoft/onnxscript/issues/2211 From 8407a907b7e782ca3c0edbee6a076d4c572a7164 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:37:28 +0300 Subject: [PATCH 22/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index a0226fa5fa..63cb34450e 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -57,8 +57,6 @@ def test_remove_unused_initializers(self): self.assertEqual(len(model.graph.initializer), 0) def test_unused_initialized_inputs_are_removed_when_requested(self): - # remove inputs in case they are initializers - # if explicitly said # https://github.com/microsoft/onnxscript/issues/2211 model = onnx.parser.parse_model( """ From eb0c8efe8fb39f437f254443d5e227fc639ee702 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:37:46 +0300 Subject: [PATCH 23/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 63cb34450e..990555129f 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -73,7 +73,7 @@ def test_unused_initialized_inputs_are_removed_when_requested(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 1) - def test_avoid_remove_unused_inputs_initializers(self): + def test_unused_initialized_inputs_are_kept_by_default(self): # supress remove inputs in case they are initializers until explicitly said model = onnx.parser.parse_model( """ From dfa809f55580eee76a3a44aa43e21e29ce5615e7 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:38:00 +0300 Subject: [PATCH 24/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 990555129f..999659a45c 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -74,7 +74,6 @@ def test_unused_initialized_inputs_are_removed_when_requested(self): self.assertEqual(len(model.graph.input), 1) def test_unused_initialized_inputs_are_kept_by_default(self): - # supress remove inputs in case they are initializers until explicitly said model = onnx.parser.parse_model( """ From 84b12cf0afd6ba84f06eb4c0d5bb6d6dbc94ad99 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 18:38:41 +0300 Subject: [PATCH 25/28] Update onnxscript/ir/passes/common/unused_removal_test.py Co-authored-by: Justin Chu --- onnxscript/ir/passes/common/unused_removal_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 999659a45c..64a3ed689d 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -89,7 +89,8 @@ def test_unused_initialized_inputs_are_kept_by_default(self): self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 2) - def test_avoid_remove_unused_inputs(self): + @parameterized.parameterized.expand([True, False]) + def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool): # preserve inputs as part of interface model = onnx.parser.parse_model( """ From 6b689ef822d7c30935139639e1326dbc7543c03e Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 19:19:14 +0300 Subject: [PATCH 26/28] Update unused_removal_test.py parametrization --- onnxscript/ir/passes/common/unused_removal_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 64a3ed689d..45556c5a74 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -102,7 +102,7 @@ def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool): } """ ) - model = self.remove_unused_nodes(model, remove_initialized_inputs=True) + model = self.remove_unused_nodes(model, remove_initialized_inputs=remove_initialized_inputs) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 2) From 866fbdc18946f6409868bee0de159888a37f4014 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 20:36:33 +0300 Subject: [PATCH 27/28] Update unused_removal_test.py --- onnxscript/ir/passes/common/unused_removal_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py index 45556c5a74..d0a27626ed 100644 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ b/onnxscript/ir/passes/common/unused_removal_test.py @@ -102,7 +102,9 @@ def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool): } """ ) - model = self.remove_unused_nodes(model, remove_initialized_inputs=remove_initialized_inputs) + model = self.remove_unused_nodes( + model, remove_initialized_inputs=remove_initialized_inputs + ) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "Mul") self.assertEqual(len(model.graph.input), 2) From d12b6b24a17af8b9f8d264261e37564a73d470a2 Mon Sep 17 00:00:00 2001 From: Alexey Biryukov Date: Tue, 22 Apr 2025 21:52:34 +0300 Subject: [PATCH 28/28] Update unused_removal.py --- onnxscript/ir/passes/common/unused_removal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py index 28f1bec837..de4446bd62 100644 --- a/onnxscript/ir/passes/common/unused_removal.py +++ b/onnxscript/ir/passes/common/unused_removal.py @@ -110,8 +110,8 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: initializers = model.graph.initializers if self.remove_initialized_inputs: graph_inputs = model.graph.inputs - for i, input in reversed(list(enumerate(graph_inputs))): - if input.name in initializers and not (input in graph_outputs or input.uses()): + for i, inp in reversed(list(enumerate(graph_inputs))): + if inp.name in initializers and not (inp in graph_outputs or inp.uses()): del graph_inputs[i] count += 1 for init in list(initializers.values()):