
Commit d0b0c34 (2 parents: d5ab484 + 290d21e)

rebase with regression fix on "[Migration][DO NOT MERGE] Separate old ir into _legacy_ir folder"

All tests except for the linter are expected to pass. [ghstack-poisoned]

File tree: 2 files changed (+38 −30 lines)


onnxscript/function_libs/torch_lib/graph_building.py (4 additions, 2 deletions)

@@ -822,7 +822,9 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
         # nn.Modules exported by dynamo exporter have unique call sites, their function
         # op_type name can serve to form the unique identifier for value info.
         # Store inside top level GraphProto.
-        new_value_info = self.generate_maingraph_value_info_proto()
+        new_value_info = self.generate_subgraphs_value_info_proto()
+        # Insert value info for nodes in top level graph.
+        new_value_info.update(self.generate_maingraph_value_info_proto())
         # Do not store input, output or initializer into value_info
         for input in onnx_model.graph.input:
             new_value_info.pop(input.name, None)
@@ -908,7 +910,7 @@ def generate_function_value_info_proto(
         return named_value_info
 
     @runtime_typing.checked
-    def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
+    def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
        """Unique naming strategies for values inside subgraphs, i.e. local functions.
 
            {function_domain::function_op_type}/{value_name}
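
The first hunk changes the merge order so that top-level-graph entries win on name collisions: subgraph value infos are collected first and dict.update then overwrites any colliding keys with main-graph entries. The return-annotation change from Mapping to Dict fits this, since the caller now mutates the returned dictionary. A minimal sketch of the merge and of the {function_domain::function_op_type}/{value_name} naming scheme described in the docstring; the qualified_name helper and the example entries are made up for illustration:

    from typing import Dict

    import onnx


    def qualified_name(domain: str, op_type: str, value_name: str) -> str:
        # Naming scheme from the generate_subgraphs_value_info_proto docstring:
        #     {function_domain::function_op_type}/{value_name}
        return f"{domain}::{op_type}/{value_name}"


    # Hypothetical stand-ins for the two generators touched by the diff.
    subgraph_info: Dict[str, onnx.ValueInfoProto] = {
        qualified_name("pkg.torch", "SomeModule", "x"): onnx.ValueInfoProto(name="x"),
    }
    maingraph_info: Dict[str, onnx.ValueInfoProto] = {
        "y": onnx.ValueInfoProto(name="y"),
    }

    # Mirrors the new code path: collect subgraph entries first, then let the
    # top-level graph's entries overwrite any colliding keys via dict.update.
    new_value_info = dict(subgraph_info)
    new_value_info.update(maingraph_info)
    assert set(new_value_info) == {"pkg.torch::SomeModule/x", "y"}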

onnxscript/function_libs/torch_lib/graph_building_test.py (34 additions, 28 deletions)

@@ -142,53 +142,41 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel
         graph.add_initializer("x", x_tensor)
 
 
+class _MLP(torch.nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.fc1 = torch.nn.Linear(input_size, hidden_size)
+        self.fc2 = torch.nn.Linear(hidden_size, output_size)
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.fc2(out)
+        return out
+
+
 @unittest.skipIf(
     IS_WINDOWS and version_utils.torch_older_than("2.3"),
     "dynamo_export not supported on Windows in PyTorch<2.3",
 )
 class TestModelSaving(unittest.TestCase):
     def test_save_initializer_to_files_for_large_model(self):
-        class MLP(torch.nn.Module):
-            def __init__(self, input_size, hidden_size, output_size):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(input_size, hidden_size)
-                self.fc2 = torch.nn.Linear(hidden_size, output_size)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
-
         # # of model parameters:
         # input_size x hidden_size + hidden_size +
         # hidden_size x output_size + output_size
         # ~= 3GB below
         batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10
-        model = MLP(input_size, hidden_size, output_size)
+        model = _MLP(input_size, hidden_size, output_size)
         x = torch.randn(batch_size, input_size)
 
         model_proto = torch.onnx.dynamo_export(model, x).model_proto
         # Assert model is larger than 2GB (~=3GB)
         self.assertGreater(model_proto.ByteSize(), 2**31)
 
     def test_input_output_and_initializer_are_not_stored_in_value_info(self):
-        class MLP(torch.nn.Module):
-            def __init__(self, input_size, hidden_size, output_size):
-                super().__init__()
-                self.fc1 = torch.nn.Linear(input_size, hidden_size)
-                self.fc2 = torch.nn.Linear(hidden_size, output_size)
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                out = self.fc1(x)
-                out = self.relu(out)
-                out = self.fc2(out)
-                return out
-
         batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
-        model = MLP(input_size, hidden_size, output_size)
+        model = _MLP(input_size, hidden_size, output_size)
         x = torch.randn(batch_size, input_size)
 
         model_proto = torch.onnx.dynamo_export(model, x).model_proto
@@ -201,6 +189,24 @@ def forward(self, x):
         for i in model_proto.graph.initializer:
             self.assertNotIn(i.name, v_names)
 
+    def test_experimental_function_value_info_are_stored_in_graph_value_info(self):
+        batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
+        model = _MLP(input_size, hidden_size, output_size)
+        x = torch.randn(batch_size, input_size)
+
+        model_proto = torch.onnx.dynamo_export(model, x).model_proto
+        v_names = {v.name for v in model_proto.graph.value_info}
+        torch_functions = [
+            f for f in model_proto.functions if f.domain.startswith("pkg.torch")
+        ]
+        self.assertNotEqual(len(torch_functions), 0)
+        for f in torch_functions:
+            for n in f.node:
+                for i in n.input:
+                    self.assertIn(f"{f.domain}::{f.name}/{i}", v_names)
+                for o in n.output:
+                    self.assertIn(f"{f.domain}::{f.name}/{o}", v_names)
+
 
 if __name__ == "__main__":
     unittest.main()
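
As a sanity check on the "~= 3GB" comment in test_save_initializer_to_files_for_large_model: with float32 initializers the hoisted _MLP holds 4×50,000,000 + 50,000,000 + 50,000,000×10 + 10 = 750,000,010 parameters, about 3 GB, which is why the test can assert ByteSize() > 2**31 (2 GiB). A quick back-of-the-envelope sketch:

    # Back-of-the-envelope size of the _MLP in the large-model test.
    input_size, hidden_size, output_size = 4, 50_000_000, 10

    n_params = (
        input_size * hidden_size + hidden_size      # fc1 weight + bias
        + hidden_size * output_size + output_size   # fc2 weight + bias
    )
    n_bytes = n_params * 4  # 4 bytes per float32 parameter

    print(n_params)  # 750000010
    print(n_bytes)   # 3000000040 bytes (~3 GB), above 2**31 = 2147483648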
