Skip to content

Commit ce90e25

Browse files
authored
Fixing issues encountered while trying to to-onnx applecider models. (#565)
* Fixing issues encountered while trying to to-onnx applecider models. * Move set_default_device to a function that is called once. * Bug fixes in `engine`. Correctly import InferenceDataSetWriter. Fix way we write onnx_results. * Copy the to_tensor.py file to the onnx output directory. * Adding initial version of dynamo-based torch onnx export.
1 parent eca88a7 commit ce90e25

File tree

5 files changed

+118
-19
lines changed

5 files changed

+118
-19
lines changed

src/hyrax/model_exporters.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,21 @@ def export_to_onnx(model, sample, config, ctx):
7878

7979

8080
def _export_pytorch_to_onnx(model, sample, output_filepath, opset_version):
81-
"""Specific implementation to convert PyTorch model to ONNX format. This
82-
function will also:
83-
- Run `sample` through the model before converting the model to ONNX
84-
- Convert `sample` to a numpy array
81+
"""Specific implementation to convert PyTorch model to ONNX format. This uses
82+
the older (torch<2.9) export capabilities. It only supports up to opset
83+
version 20.
84+
85+
Parameters
86+
----------
87+
model : torch.nn.Module
88+
The PyTorch model to be converted to ONNX format.
89+
sample : NumPy array or list of NumPy arrays
90+
A sample of input data to the model. This is used to trace the model
91+
during the export process.
92+
output_filepath : Path
93+
The file path where the ONNX model will be saved.
94+
opset_version : int
95+
The ONNX opset version to use for the export.
8596
"""
8697

8798
# deferred import to reduce start up time
@@ -106,19 +117,23 @@ def _export_pytorch_to_onnx(model, sample, output_filepath, opset_version):
106117
input_names = []
107118
dynamic_axes = {}
108119

109-
# ! Currently `sample` is either a tuple or bare numpy array. But after it
110-
# ! goes through `default_convert` above, it becomes a list of Tensors.
120+
# torch_sample is returned from default_convert as either a single Tensor or
121+
# a list of Tensors.
111122
if isinstance(torch_sample, list):
112123
for i in range(len(torch_sample)):
113124
# For supervised models, the label or target should be empty
114-
# so we will not include those in the input names.
125+
# so we will not include those in the input names. Any labels should
126+
# be the last element in the list.
115127
if len(torch_sample[i]):
116128
input_names.append(f"input_{i}")
117129
dynamic_axes[f"input_{i}"] = {0: "batch_size"}
118130
else:
119131
input_names.append("input")
120132
dynamic_axes["input"] = {0: "batch_size"}
121133

134+
# Output is assumed to always have a dynamic batch size.
135+
dynamic_axes["output"] = {0: "batch_size"}
136+
122137
# export the model to ONNX format
123138
export(
124139
model,
@@ -128,11 +143,73 @@ def _export_pytorch_to_onnx(model, sample, output_filepath, opset_version):
128143
input_names=input_names,
129144
output_names=["output"],
130145
dynamic_axes=dynamic_axes,
146+
dynamo=False, # newer versions of torch will use dynamo by default
131147
)
132148

133149
# Make sure that the output is on the CPU
134150
if sample_out.device.type != "cpu":
135151
sample_out = sample_out.to("cpu")
136152

137153
# Return the output of the model as numpy array
138-
return sample_out.numpy()
154+
return sample_out.detach().numpy()
155+
156+
157+
def _export_pytorch_to_onnx_v2(model, sample, output_filepath, opset_version):
158+
"""Currently unused.
159+
Specific implementation to convert PyTorch model to ONNX format using
160+
torch Dynamo export capabilities.
161+
162+
Parameters
163+
----------
164+
model : torch.nn.Module
165+
The PyTorch model to be converted to ONNX format.
166+
sample : NumPy array or list of NumPy arrays
167+
A sample of input data to the model. This is used to trace the model
168+
during the export process.
169+
output_filepath : Path
170+
The file path where the ONNX model will be saved.
171+
opset_version : int
172+
The ONNX opset version to use for the export.
173+
"""
174+
175+
# deferred import to reduce start up time
176+
import torch
177+
from torch.onnx import export
178+
from torch.utils.data.dataloader import default_convert
179+
180+
# set model in eval mode and move it to the CPU to prep for export to ONNX.
181+
model.train(False)
182+
model.to("cpu")
183+
184+
# set the default device to CPU and convert the sample to torch Tensors
185+
torch.set_default_device("cpu")
186+
torch_sample = default_convert(sample)
187+
188+
# Run a single sample through the model. We'll check this against the output
189+
# from the ONNX version to make sure it's the same, i.e. `np.assert_allclose`.
190+
sample_out = model(torch_sample)
191+
# Make sure that the output is on the CPU, detached, and as a numpy array
192+
sample_out = sample_out.to("cpu").detach().numpy()
193+
194+
dynamic_shapes = []
195+
batch = torch.export.Dim("batch")
196+
197+
# TODO: This should be built dynamically based on the structure of torch_sample.
198+
dynamic_shapes = [[{0: batch}, {0: batch}, {}]]
199+
200+
export(
201+
model,
202+
(torch_sample,), # exporter expects a tuple of inputs for `forward`
203+
output_filepath,
204+
opset_version=opset_version,
205+
dynamo=True,
206+
dynamic_shapes=dynamic_shapes,
207+
verbose=True,
208+
report=True,
209+
dump_exported_program=True,
210+
artifacts_dir=output_filepath.parent,
211+
input_names=["input"],
212+
output_names=["output"],
213+
)
214+
215+
return sample_out

src/hyrax/pytorch_ignite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,7 @@ def create_engine(funcname: str, device: torch.device, model: torch.nn.Module, c
548548
config : dict
549549
The runtime config in use
550550
"""
551+
torch.set_default_device(device.type)
551552
return Engine(_create_process_func(funcname, device, model, config))
552553

553554

src/hyrax/verbs/engine.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def run(self, model_directory: str = None):
3535
[x] Implement a simple strategy for reading in batches of data samples
3636
[x] Process the samples with any custom collate functions as well as a default collate function
3737
[x] Pass the collated batch to the appropriate to_tensor function
38-
[ ] Send that output to the ONNX-ified model
38+
[x] Send that output to the ONNX-ified model
3939
[x] Persist the results of inference.
4040
"""
4141
from pathlib import Path
@@ -46,7 +46,7 @@ def run(self, model_directory: str = None):
4646
create_results_dir,
4747
find_most_recent_results_dir,
4848
)
49-
from hyrax.data_sets.inference_dataset import InferenceDatasetWriter
49+
from hyrax.data_sets.inference_dataset import InferenceDataSetWriter
5050
from hyrax.plugin_utils import load_to_tensor
5151
from hyrax.pytorch_ignite import setup_dataset
5252

@@ -73,7 +73,7 @@ def run(self, model_directory: str = None):
7373
to_tensor_fn = load_to_tensor(input_directory)
7474

7575
# ~ Load the ONNX model from the input directory.
76-
onnx_file_name = input_directory / "model.onnx"
76+
onnx_file_name = input_directory / "example_model_opset_20.onnx"
7777
ort_session = onnxruntime.InferenceSession(onnx_file_name)
7878

7979
# ~ For now we use `setup_dataset` to get our datasets back. Later we can
@@ -94,7 +94,7 @@ def run(self, model_directory: str = None):
9494
# as a type hint. So we may need to separate InferenceDataset and IDWriter
9595
# to remove that dependency.
9696
result_dir = create_results_dir(config, "engine")
97-
self.results_writer = InferenceDatasetWriter(infer_dataset, result_dir)
97+
self.results_writer = InferenceDataSetWriter(infer_dataset, result_dir)
9898

9999
# Work through the dataset in steps of `batch_size`
100100
for start_idx in range(0, len(infer_dataset), batch_size):
@@ -108,9 +108,21 @@ def run(self, model_directory: str = None):
108108
# ~ Pass the collated batch to the to_tensor function
109109
prepared_batch = to_tensor_fn(collated_batch)
110110

111-
# Then we would send that output to the ONNX-ified model.
112-
ort_inputs = {ort_session.get_inputs()[0].name: prepared_batch}
113-
onnx_results = ort_session.run(None, ort_inputs) # infer with ONNX
111+
# Create the inputs array for the ONNX model using the expected inputs
112+
# from the loaded ONNX model and the type and shape of the prepared batch.
113+
ort_inputs = {}
114+
if isinstance(prepared_batch, tuple):
115+
for i in range(len(prepared_batch)):
116+
# For a supervised model, we expect that at least one of the
117+
# elements in the prepared batch will be empty, so we only
118+
# add non-empty inputs.
119+
if len(prepared_batch[i]):
120+
ort_inputs[ort_session.get_inputs()[i].name] = prepared_batch[i]
121+
else:
122+
ort_inputs = {ort_session.get_inputs()[0].name: prepared_batch}
123+
124+
# Run the ONNX model with the prepared batch as input
125+
onnx_results = ort_session.run(None, ort_inputs)
114126

115127
# ~ Finally, we persist the results of inference.
116128
# For now, collated_batch will always have an "object_id" key that
@@ -122,8 +134,10 @@ def run(self, model_directory: str = None):
122134
msg += f"Could not determine object IDs from batch. Batch has keys {collated_batch.keys()}"
123135
raise RuntimeError(msg)
124136

125-
# ~ We may not need to do the list comprehension for batch_results, it's
126-
# possible that ONNX will already return it in this form.
127-
self.results_writer.write_batch(collated_batch["object_id"], [t for t in onnx_results])
137+
# Save the output of the onnx model per batch. Onnx results are
138+
# returned as a 1-element list containing a numpy array with first
139+
# dimension as batch size.
140+
self.results_writer.write_batch(collated_batch["object_id"], [i for i in onnx_results[0]])
128141

142+
# Write the final index file for the inference results.
129143
self.results_writer.write_index()

src/hyrax/verbs/to_onnx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def run_cli(self, args=None):
2929

3030
def run(self, input_model_directory: str = None):
3131
"""Export the model to ONNX format and save it to the specified path."""
32+
import shutil
3233
from pathlib import Path
3334

3435
from hyrax.config_utils import (
@@ -66,6 +67,12 @@ def run(self, input_model_directory: str = None):
6667
config_manager = ConfigManager(runtime_config_filepath=config_file)
6768
config_from_training = config_manager.config
6869

70+
# copy the to_tensor.py file from the input directory to the output directory
71+
to_tensor_src = input_directory / "to_tensor.py"
72+
to_tensor_dst = output_dir / "to_tensor.py"
73+
if to_tensor_src.exists():
74+
shutil.copy(to_tensor_src, to_tensor_dst)
75+
6976
# Use the config file to locate and assemble the trained weight file path
7077
weights_file_path = input_directory / config_from_training["train"]["weights_filename"]
7178

tests/hyrax/test_nan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_nan_handling_off_returns_input(loopback_hyrax_nan):
133133
def to_tensor(data_dict):
134134
data = data_dict.get("data", {})
135135
if "image" in data and "label" in data:
136-
image = tensor(data["image"])
136+
image = data["image"]
137137
label = data["label"]
138138
return (image, label)
139139

0 commit comments

Comments
 (0)