@@ -142,53 +142,41 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel
         graph.add_initializer("x", x_tensor)
 
 
+class _MLP(torch.nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(input_size, hidden_size)
+        self.fc2 = torch.nn.Linear(hidden_size, output_size)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        return out
+
+
 @unittest.skipIf(
     IS_WINDOWS and version_utils.torch_older_than("2.3"),
     "dynamo_export not supported on Windows in PyTorch<2.3",
 )
 class TestModelSaving(unittest.TestCase):
     def test_save_initializer_to_files_for_large_model(self):
-        class MLP(torch.nn.Module):
-            def __init__(self, input_size, hidden_size, output_size):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(input_size, hidden_size)
-                self.fc2 = torch.nn.Linear(hidden_size, output_size)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
-
         # # of model parameters:
         #  input_size x hidden_size + hidden_size +
         #  hidden_size x output_size + output_size
         # ~= 3GB below
         batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10
-        model = MLP(input_size, hidden_size, output_size)
+        model = _MLP(input_size, hidden_size, output_size)
         x = torch.randn(batch_size, input_size)
 
         model_proto = torch.onnx.dynamo_export(model, x).model_proto
         # Assert model is larger than 2GB (~=3GB)
         self.assertGreater(model_proto.ByteSize(), 2**31)
 
     def test_input_output_and_initializer_are_not_stored_in_value_info(self):
-        class MLP(torch.nn.Module):
-            def __init__(self, input_size, hidden_size, output_size):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(input_size, hidden_size)
-                self.fc2 = torch.nn.Linear(hidden_size, output_size)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
-
         batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
-        model = MLP(input_size, hidden_size, output_size)
+        model = _MLP(input_size, hidden_size, output_size)
         x = torch.randn(batch_size, input_size)
 
         model_proto = torch.onnx.dynamo_export(model, x).model_proto
@@ -201,6 +189,24 @@ def forward(self, x):
         for i in model_proto.graph.initializer:
             self.assertNotIn(i.name, v_names)
 
+    def test_experimental_function_value_info_are_stored_in_graph_value_info(self):
+        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
+        model = _MLP(input_size, hidden_size, output_size)
+        x = torch.randn(batch_size, input_size)
+
+        model_proto = torch.onnx.dynamo_export(model, x).model_proto
+        v_names = {v.name for v in model_proto.graph.value_info}
+        torch_functions = [
+            f for f in model_proto.functions if f.domain.startswith("pkg.torch")
+        ]
+        self.assertNotEqual(len(torch_functions), 0)
+        for f in torch_functions:
+            for n in f.node:
+                for i in n.input:
+                    self.assertIn(f"{f.domain}::{f.name}/{i}", v_names)
+                for o in n.output:
+                    self.assertIn(f"{f.domain}::{f.name}/{o}", v_names)
+
 
 
 if __name__ == "__main__":
     unittest.main()
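
A quick sanity check on the "~= 3GB" comment in test_save_initializer_to_files_for_large_model; this is a minimal sketch, assuming the default float32 parameters (4 bytes each) and using only the sizes that appear in the test:

    input_size, hidden_size, output_size = 4, 50_000_000, 10
    # fc1 weights + fc1 bias + fc2 weights + fc2 bias
    n_params = input_size * hidden_size + hidden_size + hidden_size * output_size + output_size
    # n_params == 750_000_010
    n_bytes = 4 * n_params  # float32 assumption
    # n_bytes == 3_000_000_040, roughly 2.8 GiB, which is above the 2**31-byte
    # threshold the test asserts with model_proto.ByteSize()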