13
13
class RemoveUnusedTest (unittest .TestCase ):
14
14
using_ir : bool
15
15
16
- def remove_unused_nodes (
17
- self , model : onnx .ModelProto , remove_initialized_inputs : bool = False
18
- ):
16
+ def remove_unused_nodes (self , model : onnx .ModelProto ):
19
17
if self .using_ir :
20
18
model_ir = ir .serde .deserialize_model (model )
21
- onnxscript .optimizer .remove_unused_nodes (model_ir , remove_initialized_inputs )
19
+ onnxscript .optimizer .remove_unused_nodes (model_ir )
22
20
model = ir .serde .serialize_model (model_ir )
23
21
return model
24
- onnxscript .optimizer .remove_unused_nodes (model , remove_initialized_inputs )
22
+ onnxscript .optimizer .remove_unused_nodes (model )
25
23
return model
26
24
27
25
def test_remove_unused_nodes (self ):
@@ -56,24 +54,7 @@ def test_remove_unused_initializers(self):
56
54
self .assertEqual (model .graph .node [0 ].op_type , "Mul" )
57
55
self .assertEqual (len (model .graph .initializer ), 0 )
58
56
59
- def test_unused_initialized_inputs_are_removed_when_requested (self ):
60
- # https://github.com/microsoft/onnxscript/issues/2211
61
- model = onnx .parser .parse_model (
62
- """
63
- <ir_version: 10, opset_import: [ "" : 17]>
64
- agraph (float[N] x, float[N] two) => (float[N] z)
65
- <float two = {2.0,2.0}> {
66
- four = Add(two, two)
67
- z = Mul(x, x)
68
- }
69
- """
70
- )
71
- model = self .remove_unused_nodes (model , remove_initialized_inputs = True )
72
- self .assertEqual (len (model .graph .node ), 1 )
73
- self .assertEqual (model .graph .node [0 ].op_type , "Mul" )
74
- self .assertEqual (len (model .graph .input ), 1 )
75
-
76
- def test_unused_initialized_inputs_are_kept_by_default (self ):
57
+ def test_unused_initialized_inputs_are_kept (self ):
77
58
model = onnx .parser .parse_model (
78
59
"""
79
60
<ir_version: 10, opset_import: [ "" : 17]>
@@ -88,9 +69,9 @@ def test_unused_initialized_inputs_are_kept_by_default(self):
88
69
self .assertEqual (len (model .graph .node ), 1 )
89
70
self .assertEqual (model .graph .node [0 ].op_type , "Mul" )
90
71
self .assertEqual (len (model .graph .input ), 2 )
72
+ self .assertEqual (len (model .graph .initializer ), 1 )
91
73
92
- @parameterized .parameterized .expand ([True , False ])
93
- def test_unused_inputs_are_not_removed (self , remove_initialized_inputs : bool ):
74
+ def test_unused_inputs_are_not_removed (self ):
94
75
# preserve inputs as part of interface
95
76
model = onnx .parser .parse_model (
96
77
"""
@@ -102,9 +83,7 @@ def test_unused_inputs_are_not_removed(self, remove_initialized_inputs: bool):
102
83
}
103
84
"""
104
85
)
105
- model = self .remove_unused_nodes (
106
- model , remove_initialized_inputs = remove_initialized_inputs
107
- )
86
+ model = self .remove_unused_nodes (model )
108
87
self .assertEqual (len (model .graph .node ), 1 )
109
88
self .assertEqual (model .graph .node [0 ].op_type , "Mul" )
110
89
self .assertEqual (len (model .graph .input ), 2 )
0 commit comments