 from __future__ import annotations

 import os
-import sys
 import unittest

 import torch

 import onnxscript.testing
 from onnxscript import FLOAT, evaluator
 from onnxscript import opset18 as op
-from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_lib import graph_building, ops

 IS_WINDOWS = os.name == "nt"
@@ -157,79 +155,5 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel
         graph.add_initializer("x", x_tensor)
 
 
-class _MLP(torch.nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
-        super().__init__()
-        self.fc1 = torch.nn.Linear(input_size, hidden_size)
-        self.fc2 = torch.nn.Linear(hidden_size, output_size)
-        self.relu = torch.nn.ReLU()
-
-    def forward(self, x):
-        out = self.fc1(x)
-        out = self.relu(out)
-        out = self.fc2(out)
-        return out
-
-
-@unittest.skipIf(
-    IS_WINDOWS and version_utils.torch_older_than("2.3"),
-    "dynamo_export not supported on Windows in PyTorch<2.3",
-)
-@unittest.skipIf(
-    sys.version_info > (3, 11),
-    "dynamo_export not supported due to torch.compile not functional for python>3.11",
-)
-class TestModelSaving(unittest.TestCase):
-    def test_save_initializer_to_files_for_large_model(self):
-        # # of model parameters:
-        # input_size x hidden_size + hidden_size +
-        # hidden_size x output_size + output_size
-        # ~= 3GB below
-        batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10
-        model = _MLP(input_size, hidden_size, output_size)
-        x = torch.randn(batch_size, input_size)
-
-        model_proto = torch.onnx.dynamo_export(model, x).model_proto
-        # Assert model is larger than 2GB (~=3GB)
-        self.assertGreater(model_proto.ByteSize(), 2**31)
-
-    def test_input_output_and_initializer_are_not_stored_in_value_info(self):
-        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
-        model = _MLP(input_size, hidden_size, output_size)
-        x = torch.randn(batch_size, input_size)
-
-        model_proto = torch.onnx.dynamo_export(model, x).model_proto
-        v_names = {v.name for v in model_proto.graph.value_info}
-
-        for i in model_proto.graph.input:
-            self.assertNotIn(i.name, v_names)
-        for o in model_proto.graph.output:
-            self.assertNotIn(o.name, v_names)
-        for i in model_proto.graph.initializer:
-            self.assertNotIn(i.name, v_names)
-
-    @unittest.skipIf(
-        not version_utils.torch_older_than("2.4"),
-        "PyTorch 2.4-preview optimizes the functions away",
-    )
-    def test_experimental_function_value_info_are_stored_in_graph_value_info(self):
-        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
-        model = _MLP(input_size, hidden_size, output_size)
-        x = torch.randn(batch_size, input_size)
-
-        model_proto = torch.onnx.dynamo_export(model, x).model_proto
-        v_names = {v.name for v in model_proto.graph.value_info}
-        torch_functions = [
-            f for f in model_proto.functions if f.domain.startswith("pkg.torch")
-        ]
-        self.assertNotEqual(len(torch_functions), 0)
-        for f in torch_functions:
-            for n in f.node:
-                for i in n.input:
-                    self.assertIn(f"{f.domain}::{f.name}/{i}", v_names)
-                for o in n.output:
-                    self.assertIn(f"{f.domain}::{f.name}/{o}", v_names)
-
-
 if __name__ == "__main__":
     unittest.main()