@@ -78,10 +78,21 @@ def export_to_onnx(model, sample, config, ctx):
7878
7979
8080def _export_pytorch_to_onnx (model , sample , output_filepath , opset_version ):
81- """Specific implementation to convert PyTorch model to ONNX format. This
82- function will also:
83- - Run `sample` through the model before converting the model to ONNX
84- - Convert `sample` to a numpy array
81+ """Specific implementation to convert PyTorch model to ONNX format. This uses
82+ the older (torch<2.9) export capabilities. It only supports up to opset
83+ version 20.
84+
85+ Parameters
86+ ----------
87+ model : torch.nn.Module
88+ The PyTorch model to be converted to ONNX format.
89+ sample : NumPy array or list of NumPy arrays
90+ A sample of input data to the model. This is used to trace the model
91+ during the export process.
92+ output_filepath : Path
93+ The file path where the ONNX model will be saved.
94+ opset_version : int
95+ The ONNX opset version to use for the export.
8596 """
8697
8798 # deferred import to reduce start up time
@@ -106,19 +117,23 @@ def _export_pytorch_to_onnx(model, sample, output_filepath, opset_version):
106117 input_names = []
107118 dynamic_axes = {}
108119
109- # ! Currently `sample` is either a tuple or bare numpy array. But after it
110- # ! goes through `default_convert` above, it becomes a list of Tensors.
120+ # torch_sample is returned from default_convert as either a single Tensor or
121+ # a list of Tensors.
111122 if isinstance (torch_sample , list ):
112123 for i in range (len (torch_sample )):
113124 # For supervised models, the label or target should be empty
114- # so we will not include those in the input names.
125+ # so we will not include those in the input names. Any labels should
126+ # be the last element in the list.
115127 if len (torch_sample [i ]):
116128 input_names .append (f"input_{ i } " )
117129 dynamic_axes [f"input_{ i } " ] = {0 : "batch_size" }
118130 else :
119131 input_names .append ("input" )
120132 dynamic_axes ["input" ] = {0 : "batch_size" }
121133
134+ # Output is assumed to always have a dynamic batch size.
135+ dynamic_axes ["output" ] = {0 : "batch_size" }
136+
122137 # export the model to ONNX format
123138 export (
124139 model ,
@@ -128,11 +143,73 @@ def _export_pytorch_to_onnx(model, sample, output_filepath, opset_version):
128143 input_names = input_names ,
129144 output_names = ["output" ],
130145 dynamic_axes = dynamic_axes ,
146+ dynamo = False , # newer versions of torch will use dynamo by default
131147 )
132148
133149 # Make sure that the output is on the CPU
134150 if sample_out .device .type != "cpu" :
135151 sample_out = sample_out .to ("cpu" )
136152
137153 # Return the output of the model as numpy array
138- return sample_out .numpy ()
154+ return sample_out .detach ().numpy ()
155+
156+
def _export_pytorch_to_onnx_v2(model, sample, output_filepath, opset_version):
    """Currently unused.
    Specific implementation to convert PyTorch model to ONNX format using
    torch Dynamo export capabilities.

    Parameters
    ----------
    model : torch.nn.Module
        The PyTorch model to be converted to ONNX format.
    sample : NumPy array or list of NumPy arrays
        A sample of input data to the model. This is used to trace the model
        during the export process.
    output_filepath : Path
        The file path where the ONNX model will be saved.
    opset_version : int
        The ONNX opset version to use for the export.

    Returns
    -------
    numpy.ndarray
        The model's output for ``sample`` on the CPU, intended to be compared
        against the exported ONNX model's output (e.g. via
        ``np.testing.assert_allclose``).
    """

    # deferred import to reduce start up time
    import torch
    from torch.onnx import export
    from torch.utils.data.dataloader import default_convert

    # set model in eval mode and move it to the CPU to prep for export to ONNX.
    model.train(False)
    model.to("cpu")

    # set the default device to CPU and convert the sample to torch Tensors
    torch.set_default_device("cpu")
    torch_sample = default_convert(sample)

    # Run a single sample through the model. We'll check this against the output
    # from the ONNX version to make sure it's the same, i.e. `np.assert_allclose`.
    sample_out = model(torch_sample)
    # Make sure that the output is on the CPU, detached, and as a numpy array
    sample_out = sample_out.to("cpu").detach().numpy()

    # Mark dimension 0 (the batch dimension) as dynamic in the exported graph.
    batch = torch.export.Dim("batch")

    # TODO: This should be built dynamically based on the structure of
    # torch_sample; the two-tensor layout below is hard-coded for now.
    dynamic_shapes = [[{0: batch}, {0: batch}, {}]]

    export(
        model,
        (torch_sample,),  # exporter expects a tuple of inputs for `forward`
        output_filepath,
        opset_version=opset_version,
        dynamo=True,
        dynamic_shapes=dynamic_shapes,
        verbose=True,
        report=True,  # writes a Markdown export report next to the model
        dump_exported_program=True,
        artifacts_dir=output_filepath.parent,
        input_names=["input"],
        output_names=["output"],
    )

    return sample_out
0 commit comments