forked from huggingface/optimum
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconvert.py
More file actions
1267 lines (1089 loc) · 57.5 KB
/
convert.py
File metadata and controls
1267 lines (1089 loc) · 57.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ONNX model check and export functions."""
import copy
import gc
import multiprocessing as mp
import os
import traceback
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
from transformers.generation import GenerationMixin
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available
from ...onnx.graph_transformations import check_and_save_model
from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data
from ...utils import (
DEFAULT_DUMMY_SHAPES,
ONNX_WEIGHTS_NAME,
TORCH_MINIMUM_VERSION,
is_diffusers_available,
is_onnxslim_available,
is_torch_onnx_support_available,
is_transformers_version,
logging,
require_numpy_strictly_lower,
)
from ...utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from ..utils import check_dummy_inputs_are_allowed
from .base import OnnxConfig
from .constants import UNPICKABLE_ARCHS
from .model_configs import SpeechT5OnnxConfig
from .utils import (
MODEL_TYPES_REQUIRING_POSITION_IDS,
PickableInferenceSession,
_get_submodels_and_onnx_configs,
recursive_to_device,
)
# TODO : moved back onnx imports applied in https://github.com/huggingface/optimum/pull/2114/files after refactorization
if is_torch_available():
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
if is_diffusers_available():
from diffusers import DiffusionPipeline, ModelMixin
if is_tf_available():
from transformers.modeling_tf_utils import TFPreTrainedModel
# Module-level logger obtained from optimum's logging utilities.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DynamicAxisNameError(ValueError):
    """
    Raised during ONNX validation when an output axis of the exported model has a symbolic (dynamic) name that is
    not among the dynamic axis names declared in the inputs/outputs of the ONNX config.
    """

    pass
def validate_models_outputs(
    models_and_onnx_configs: Dict[
        str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
    ],
    onnx_named_outputs: List[List[str]],
    output_dir: Path,
    atol: Optional[float] = None,
    onnx_files_subpaths: Optional[List[str]] = None,
    input_shapes: Optional[Dict] = None,
    device: str = "cpu",
    use_subprocess: Optional[bool] = True,
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Validates the export of several models, by checking that the outputs from both the reference and the exported model match.
    The following method validates the ONNX models exported using the `export_models` method.

    Every submodel is validated even if an earlier one fails: all failures except the last one are logged, and the
    last failure is re-raised once every submodel has been processed.

    Args:
        models_and_onnx_configs (`Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `OnnxConfig`]]):
            A dictionary containing the models to validate and their corresponding onnx configs.
        onnx_named_outputs (`List[List[str]]`):
            The names of the outputs to check, one list per submodel, in the order of `models_and_onnx_configs`.
        output_dir (`Path`):
            Output directory where the exported ONNX models are stored.
        atol (`Optional[float]`, defaults to `None`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.
        onnx_files_subpaths (`Optional[List[str]]`, defaults to `None`):
            The relative paths from `output_dir` to the ONNX files to do validation on. The order must be the same as the order of submodels
            in the ordered dict `models_and_onnx_configs`. If None, will use the keys from the `models_and_onnx_configs` as names.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes to validate the ONNX model on.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX models will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch.
        use_subprocess (`Optional[bool]`, defaults to `True`):
            Launch validation of each exported model in a subprocess.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export and validation.

    Raises:
        ValueError: If the number of `onnx_named_outputs` (or of `onnx_files_subpaths`, when given) does not match the
            number of models.
        Exception: The last exception raised by `validate_model_outputs` (e.g. on mismatching output names, shapes or
            values) is re-raised after all submodels have been validated.
    """
    num_models = len(models_and_onnx_configs)

    if len(onnx_named_outputs) != num_models:
        raise ValueError(
            f"Invalid number of ONNX named outputs. Required {num_models}, Provided {len(onnx_named_outputs)}"
        )

    if onnx_files_subpaths is not None and len(onnx_files_subpaths) != num_models:
        raise ValueError(
            f"Provided custom names {onnx_files_subpaths} for the validation of {num_models} models. Please provide the same number of ONNX file names as models to export."
        )

    if use_subprocess:
        logger.info("Validating models in subprocesses...")

    exceptions = []  # run all validations before raising
    for i, (model_name, (submodel, sub_onnx_config)) in enumerate(models_and_onnx_configs.items()):
        onnx_model_path = (
            output_dir.joinpath(onnx_files_subpaths[i])
            if onnx_files_subpaths is not None
            else output_dir.joinpath(model_name + ".onnx")
        )
        try:
            # When `use_subprocess` is True, each validation runs in its own subprocess, as ONNX Runtime has the bad
            # habit of not releasing memory once an InferenceSession is initialized.
            # Reference: https://github.com/huggingface/optimum/pull/1115
            validate_model_outputs(
                config=sub_onnx_config,
                reference_model=submodel,
                onnx_model=onnx_model_path,
                onnx_named_outputs=onnx_named_outputs[i],
                atol=atol,
                input_shapes=input_shapes,
                device=device,
                use_subprocess=use_subprocess,
                model_kwargs=model_kwargs,
            )
        except Exception as e:
            exceptions.append((onnx_model_path, e))

    if exceptions:
        # Log every failure but the last one, which is re-raised as the overall result.
        for failed_model_path, failure in exceptions[:-1]:
            logger.error(f"Validation for the model {failed_model_path.as_posix()} raised: {failure}")
        raise exceptions[-1][1]
def validate_model_outputs(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    onnx_model: Path,
    onnx_named_outputs: List[str],
    atol: Optional[float] = None,
    input_shapes: Optional[Dict] = None,
    device: str = "cpu",
    use_subprocess: Optional[bool] = True,
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Validates the export by checking that the outputs from both the reference and the exported model match.

    Args:
        config ([`~OnnxConfig`]:
            The configuration used to export the model.
        reference_model ([`~PreTrainedModel`] or [`~TFPreTrainedModel`]):
            The model used for the export.
        onnx_model (`Path`):
            The path to the exported model.
        onnx_named_outputs (`List[str]`):
            The names of the outputs to check.
        atol (`Optional[float]`, defaults to `None`):
            The absolute tolerance in terms of outputs difference between the reference and the exported model.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes to validate the ONNX model on.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX model will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch.
        use_subprocess (`Optional[bool]`, defaults to `True`):
            Launch validation of each exported model in a subprocess. Note that this sets the process-wide
            multiprocessing start method to `"spawn"`.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export and validation.

    Raises:
        ValueError: If the outputs shapes or values do not match between the reference and the exported model.
        RuntimeError: If the validation subprocess exits with a non-zero exit code without reporting an exception.
    """
    if not use_subprocess:
        _run_validation(
            config,
            reference_model,
            onnx_model,
            onnx_named_outputs,
            atol,
            input_shapes,
            device,
            model_kwargs=model_kwargs,
        )
        return

    # InferenceSession do not support the fork start method with some EP: https://github.com/microsoft/onnxruntime/issues/7846
    mp.set_start_method("spawn", force=True)

    io_process = ValidationProcess(
        config, reference_model, onnx_model, onnx_named_outputs, atol, input_shapes, device, model_kwargs
    )
    io_process.start()
    io_process.join()

    if io_process.exception:
        error, child_traceback = io_process.exception
        # The re-raised error carries no traceback from the subprocess; chain the formatted one as its cause so the
        # actual failure location is visible.
        raise error from RuntimeError(f"Traceback of the validation subprocess:\n{child_traceback}")

    if io_process.exitcode != 0:
        # The subprocess died without reporting an exception (e.g. it was killed, crashed in native code, or could
        # not send its exception back): do not let the validation pass silently.
        raise RuntimeError(
            f"The validation subprocess for {onnx_model.as_posix()} exited unexpectedly with exit code "
            f"{io_process.exitcode}."
        )
def _run_validation(
    config: OnnxConfig,
    reference_model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    onnx_model: Path,
    onnx_named_outputs: List[str],
    atol: Optional[float] = None,
    input_shapes: Optional[Dict] = None,
    device: str = "cpu",
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Runs the reference model and the exported ONNX model on the same dummy inputs and compares their outputs.

    Args:
        config ([`~OnnxConfig`]):
            The configuration used to export the model.
        reference_model ([`~PreTrainedModel`], [`~TFPreTrainedModel`] or [`ModelMixin`]):
            The model used for the export.
        onnx_model (`Path`):
            The path to the exported model.
        onnx_named_outputs (`List[str]`):
            The names of the ONNX outputs to check.
        atol (`Optional[float]`, defaults to `None`):
            The absolute tolerance; defaults to `config.ATOL_FOR_VALIDATION`.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, the shapes used to generate the dummy inputs.
        device (`str`, defaults to `"cpu"`):
            Device used for validation; a value starting with `cuda` selects the CUDA execution provider.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Keyword arguments forwarded to `config.patch_model_for_export`.

    Raises:
        ImportError: If the reference model comes from `diffusers` and `diffusers` is not installed.
        OutputMatchError: If the ONNX outputs differ from the config outputs, or are not a subset of the reference outputs.
        DynamicAxisNameError: If an ONNX output has a dynamic axis whose name is not declared in the config.
        ShapeError: If some output shapes differ between the reference and the ONNX model.
        AtolError: If some output values differ by more than `atol`.
    """
    from onnxruntime import GraphOptimizationLevel, SessionOptions

    model_kwargs = model_kwargs if model_kwargs is not None else {}

    logger.info(f"\nValidating ONNX model {onnx_model.as_posix()}...")

    if atol is None:
        atol = config.ATOL_FOR_VALIDATION

    if "diffusers" in str(reference_model.__class__) and not is_diffusers_available():
        raise ImportError("The pip package `diffusers` is required to validate diffusion ONNX models.")

    # The reference model type cannot change during validation, so the framework check is done once.
    is_torch_model = is_torch_available() and isinstance(reference_model, nn.Module)
    framework = "pt" if is_torch_model else "tf"

    if input_shapes is None:
        input_shapes = {}  # will use the defaults from DEFAULT_DUMMY_SHAPES
    reference_model_inputs = config.generate_dummy_inputs(framework=framework, **input_shapes)

    # Create ONNX Runtime session
    session_options = SessionOptions()
    # We could well set ORT_DISABLE_ALL here, but it makes CUDA export with O4 of gpt_neo fail
    session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC

    provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"

    session = PickableInferenceSession(onnx_model.as_posix(), sess_options=session_options, providers=[provider])

    # Sometimes the exported model can have more outputs than what is specified in the ONNX config because the original
    # PyTorch model has more outputs that were forgotten in the config, so we check for that.
    all_onnx_outputs = {output.name for output in session.get_outputs()}
    config_outputs = set(config.outputs)
    if all_onnx_outputs != config_outputs:
        # Report the outputs missing on either side (symmetric difference), in a deterministic order.
        diff = sorted(all_onnx_outputs ^ config_outputs)
        raise OutputMatchError(
            "The exported ONNX model does not have the exact same outputs as what is provided in "
            f"{config.__class__.__name__}. Difference: {', '.join(diff)}"
        )

    # Sometimes the exported model can have axes that are inferred as dynamic axes but were not specified as such in
    # the ONNX Config: it was either an error on the config side, or an error on the ONNX side inferring a dynamic axis
    # that is actually static.
    # The `OnnxConfig.fix_dynamic_axes` method should fix that at export time, but it is still worth checking here.
    all_config_dynamic_axes_names = set()
    for input_ in config.inputs.values():
        all_config_dynamic_axes_names |= set(input_.values())
    for output in config.outputs.values():
        all_config_dynamic_axes_names |= set(output.values())

    for node in session.get_outputs():
        for idx, axis in enumerate(node.shape):
            if isinstance(axis, str) and axis not in all_config_dynamic_axes_names:
                raise DynamicAxisNameError(
                    f"The axis {idx} of input / output node called {node.name} has an unknown name: {axis}"
                )

    # Compute outputs from the reference model
    if is_torch_model:
        reference_model.to(device)

        for key, value in reference_model_inputs.items():
            reference_model_inputs[key] = recursive_to_device(value=value, device=device)

    # Some models may modify in place the inputs, hence the copy.
    copy_reference_model_inputs = copy.deepcopy(reference_model_inputs)
    copy_reference_model_inputs = config.rename_ambiguous_inputs(copy_reference_model_inputs)

    with config.patch_model_for_export(reference_model, model_kwargs=model_kwargs):
        if is_torch_model:
            with torch.inference_mode():
                ref_outputs = reference_model(**copy_reference_model_inputs)
        else:
            ref_outputs = reference_model(**copy_reference_model_inputs)

    ref_outputs_dict = {}

    # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
    for name, value in ref_outputs.items():
        # Overwriting the output name as "present" since it is the name used for the ONNX outputs
        # ("past_key_values" being taken for the ONNX inputs)
        if name == "past_key_values":
            name = "present"
        if isinstance(value, (list, tuple)):
            onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
            value = config.flatten_output_collection_property(onnx_output_name, value)
            ref_outputs_dict.update(value)
        else:
            ref_outputs_dict[name] = value

    onnx_input_names = [inp.name for inp in session.get_inputs()]

    # Possibly edit the input for the onnxruntime.InferenceSession, this is for example the case for merged
    # models where the input `use_cache_branch` is added
    reference_ort_inputs = config.generate_dummy_inputs_for_validation(
        reference_model_inputs, onnx_input_names=onnx_input_names
    )

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_ort_inputs.items():
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
        elif isinstance(value, dict):
            onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
        else:
            onnx_inputs[name] = value.cpu().numpy()

    # Compute outputs from the ONNX model
    onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)

    # Modify the ONNX output names to match the reference model output names
    onnx_to_torch = {v: k for k, v in config.torch_to_onnx_output_map.items()}
    onnx_named_outputs = [onnx_to_torch.get(k, k) for k in onnx_named_outputs]

    # Check we have a subset of the keys into onnx_outputs against ref_outputs
    ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs)
    if not onnx_outputs_set.issubset(ref_outputs_set):
        raise OutputMatchError(
            "ONNX model output names do not match reference model output names.\n"
            f"Reference model output names: {ref_outputs_set}\n"
            f"ONNX model output names: {onnx_outputs_set}\n"
            f"Difference: {onnx_outputs_set.difference(ref_outputs_set)}"
        )
    else:
        onnx_output_names = ", ".join(onnx_outputs_set)
        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_output_names})")

    # Check the shape and values match
    shape_failures = []
    value_failures = []
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
        if is_torch_model:
            ref_value = ref_outputs_dict[name].detach().cpu().numpy()
        else:
            ref_value = ref_outputs_dict[name].cpu().numpy()

        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
        shapes_match = ort_value.shape == ref_value.shape
        if not shapes_match:
            logger.error(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
            shape_failures.append((name, ref_value.shape, ort_value.shape))
        else:
            logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")

        # Values
        try:
            values_close = np.allclose(ref_value, ort_value, atol=atol)
        except Exception:
            # If shapes do not match, np.allclose may fail; the shape failure recorded above is raised right after.
            if shapes_match:
                logger.warning(f"\t\t-[!] values of output {name} could not be compared")
            continue

        if values_close:
            logger.info(f"\t\t-[✓] all values close (atol: {atol})")
        else:
            # Subtract in a widened dtype: a plain `ref_value - ort_value` raises for boolean arrays and wraps
            # around for unsigned integers, which would hide (or misreport) the mismatch.
            diff_dtype = np.result_type(ref_value.dtype, ort_value.dtype, np.float64)
            max_diff = np.amax(np.abs(ref_value.astype(diff_dtype) - ort_value.astype(diff_dtype)))
            logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})")
            value_failures.append((name, max_diff))

    if shape_failures:
        msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
        raise ShapeError(f"Output shapes do not match between reference model and ONNX exported model:\n{msg}")

    if value_failures:
        msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
        atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"

        if isinstance(config, SpeechT5OnnxConfig):
            atol_msg += "\nIMPORTANT NOTE: SpeechT5 uses a dropout at inference and the output validation of ONNX Runtime inference vs PyTorch is expected to fail. Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727"

        raise AtolError(atol_msg)
class ValidationProcess(mp.Process):
    """
    Process running `_run_validation` for one exported ONNX model.

    Any `Exception` raised during validation is caught in the child and sent back to the parent through a pipe as an
    `(exception, formatted_traceback)` tuple, retrievable through the `exception` property.
    """

    def __init__(
        self,
        config: OnnxConfig,
        reference_model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
        onnx_model: Path,
        onnx_named_outputs: List[str],
        atol: Optional[float] = None,
        input_shapes: Optional[Dict] = None,
        device: str = "cpu",
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        # Parent end / child end of the pipe used to report a validation exception.
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

        self.config = config
        self.reference_model = reference_model
        self.onnx_model = onnx_model
        self.onnx_named_outputs = onnx_named_outputs
        self.atol = atol
        self.input_shapes = input_shapes
        self.device = device
        self.model_kwargs = model_kwargs

    def run(self):
        """Runs the validation in the child process and reports any exception to the parent."""
        try:
            _run_validation(
                config=self.config,
                reference_model=self.reference_model,
                onnx_model=self.onnx_model,
                onnx_named_outputs=self.onnx_named_outputs,
                atol=self.atol,
                input_shapes=self.input_shapes,
                device=self.device,
                model_kwargs=self.model_kwargs,
            )
        except Exception as e:
            tb = traceback.format_exc()
            try:
                self._cconn.send((e, tb))
            except Exception:
                # The exception could not be pickled. Send a picklable stand-in carrying its type and message
                # instead of crashing here, which would leave the parent without any reported failure.
                self._cconn.send((RuntimeError(f"{type(e).__name__}: {e}"), tb))
            return

    @property
    def exception(self):
        """The `(exception, traceback)` tuple sent by the child process, or `None` if nothing was reported."""
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception
def export_pytorch(
    model: Union["PreTrainedModel", "ModelMixin"],
    config: OnnxConfig,
    opset: int,
    output: Path,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    no_dynamic_axes: bool = False,
    do_constant_folding: bool = True,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
    """
    Exports a PyTorch model to an ONNX Intermediate Representation.

    Args:
        model ([`PreTrainedModel`] or [`ModelMixin`]):
            The model to export.
        config ([`~exporters.onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Path to save the exported ONNX file to.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
        no_dynamic_axes (bool, defaults to `False`):
            If True, disables the use of dynamic axes during ONNX export.
        do_constant_folding (bool, defaults to `True`):
            PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export. This argument should be used along the `custom_onnx_config` argument
            in case, for example, the model inputs/outputs are changed (for example, if
            `model_kwargs={"output_attentions": True}` is passed). The caller's dictionary is not modified.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration.
    """
    from torch.onnx import export as onnx_export
    from torch.utils._pytree import tree_map

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    force_onnx_external_data = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"

    # Work on a copy so that the entries added below do not leak into the caller's dictionary.
    model_kwargs = dict(model_kwargs) if model_kwargs else {}

    # num_logits_to_keep was added in transformers 4.45 and isn't added as inputs when exporting the model
    if is_transformers_version(">=", "4.45"):
        logits_to_keep_name = "logits_to_keep" if is_transformers_version(">=", "4.49") else "num_logits_to_keep"
        if logits_to_keep_name in signature(model.forward).parameters.keys():
            model_kwargs[logits_to_keep_name] = 0

    with torch.no_grad():
        model.config.return_dict = True
        model = model.eval()

        # Check if we need to override certain configuration item
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        if input_shapes is None:
            input_shapes = {}  # will use the defaults from DEFAULT_DUMMY_SHAPES

        # Check that inputs match, and order them properly
        dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)

        device = torch.device(device)

        def remap(value):
            # Move tensors to the export device, leave any other leaf untouched.
            if isinstance(value, torch.Tensor):
                value = value.to(device)
            return value

        if device.type == "cuda" and torch.cuda.is_available():
            model.to(device)
            dummy_inputs = tree_map(remap, dummy_inputs)

        dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)

        with config.patch_model_for_export(model, model_kwargs=model_kwargs):
            check_dummy_inputs_are_allowed(model, dummy_inputs)

            inputs = config.ordered_inputs(model)
            input_names = list(inputs.keys())
            output_names = list(config.outputs.keys())

            if no_dynamic_axes:
                dynamic_axes = None
            else:
                dynamic_axes = dict(chain(inputs.items(), config.outputs.items()))

            # Export can work with named args but the dict containing named args has to be the last element of the args
            # tuple.
            onnx_export(
                model,
                (dummy_inputs,),
                f=output.as_posix(),
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                do_constant_folding=do_constant_folding,
                opset_version=opset,
            )

        # check if external data was exported
        onnx_model = onnx.load(str(output), load_external_data=False)
        model_uses_external_data = check_model_uses_external_data(onnx_model)

        if model_uses_external_data or force_onnx_external_data:
            tensors_paths = _get_onnx_external_data_tensors(onnx_model)
            constant_paths = _get_onnx_external_constants(onnx_model)
            logger.info("Saving external data to one file...")

            # try free model memory
            del model
            del onnx_model
            gc.collect()
            if device.type == "cuda" and torch.cuda.is_available():
                torch.cuda.empty_cache()

            # this will probably be too memory heavy for large models
            onnx_model = onnx.load(str(output), load_external_data=True)
            data_location = output.name + "_data"
            onnx.save(
                onnx_model,
                str(output),
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=data_location,
                size_threshold=1024 if not force_onnx_external_data else 100,
                convert_attribute=True,
            )

            # delete previous external data, never removing the file that was just written
            new_data_path = output.parent / data_location
            for tensor in tensors_paths:
                if output.parent / tensor != new_data_path:
                    os.remove(output.parent / tensor)

            for tensor in constant_paths:
                if output.parent / tensor != new_data_path and os.path.isfile(output.parent / tensor):
                    os.remove(output.parent / tensor)

    return input_names, output_names
@require_numpy_strictly_lower("1.24.0", "The Tensorflow ONNX export only supports numpy<1.24.0.")
def export_tensorflow(
    model: "TFPreTrainedModel",
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Exports a TensorFlow model to an ONNX Intermediate Representation.

    Args:
        model ([`TFPreTrainedModel`]):
            The model to export.
        config ([`~exporters.onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Path to save the exported ONNX file to.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration.
    """
    # Imported lazily: TensorFlow and tf2onnx are optional dependencies. `onnx` is already imported at module level,
    # so no `sys.path` manipulation is needed (the previous one permanently dropped `sys.path[0]`).
    import tensorflow as tf
    import tf2onnx

    logger.info(f"Using framework TensorFlow: {tf.__version__}")

    model.config.return_dict = True

    # Check if we need to override certain configuration item
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            setattr(model.config, override_config_key, override_config_value)

    # Ensure inputs match
    dummy_inputs = config.generate_dummy_inputs(framework="tf")
    check_dummy_inputs_are_allowed(model, dummy_inputs)

    inputs = config.ordered_inputs(model)
    input_names = list(inputs.keys())
    output_names = list(config.outputs.keys())

    # Build the input signature from the dummy inputs, marking every axis declared in the ONNX config as `None`.
    input_signature = []
    for key, tensor in dummy_inputs.items():
        shape = [tensor.shape[i] for i in range(tensor.ndim)]
        for idx in config.inputs[key]:
            shape[idx] = None

        input_signature.append(tf.TensorSpec(shape, dtype=tensor.dtype, name=key))

    with config.patch_model_for_export(model):
        onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)

    onnx.save(
        onnx_model,
        output.as_posix(),
        convert_attribute=True,
    )

    return input_names, output_names
def export_models(
    models_and_onnx_configs: Dict[
        str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
    ],
    output_dir: Path,
    opset: Optional[int] = None,
    output_names: Optional[List[str]] = None,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    disable_dynamic_axes_fix: Optional[bool] = False,
    dtype: Optional[str] = None,
    no_dynamic_axes: bool = False,
    do_constant_folding: bool = True,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Exports several PyTorch or TensorFlow (sub)models, each one to its own ONNX file.

    This is used, for example, to export the encoder and decoder components of an encoder-decoder model as
    separate ONNX files. Each submodel is exported with [`export`], sharing all the export arguments below.

    Args:
        models_and_onnx_configs (`Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`, `ModelMixin`], `OnnxConfig`]]`):
            A dictionary mapping each submodel name to the model to export and its corresponding ONNX config.
            Submodels are exported in the dictionary's iteration order.
        output_dir (`Path`):
            Output directory to store the exported ONNX models. A `str` path is also accepted.
        opset (`Optional[int]`, defaults to `None`):
            The version of the ONNX operator set to use.
        output_names (`Optional[List[str]]`, defaults to `None`):
            The file names (relative to `output_dir`) to use for the exported ONNX files. The order must be the same
            as the order of submodels in `models_and_onnx_configs`.
            If `None`, `"{model_name}.onnx"` is used for each key `model_name` of `models_and_onnx_configs`.
        device (`str`, defaults to `"cpu"`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
        disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`):
            Whether to disable the default dynamic axes fixing.
        dtype (`Optional[str]`, defaults to `None`):
            Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
        no_dynamic_axes (bool, defaults to `False`):
            If True, disables the use of dynamic axes during ONNX export.
        do_constant_folding (bool, defaults to `True`):
            PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export. This argument should be used along the `custom_onnx_config` argument
            in case, for example, the model inputs/outputs are changed (for example, if
            `model_kwargs={"output_attentions": True}` is passed).

    Returns:
        `Tuple[List[List[str]], List[List[str]]]`: A pair `(inputs, outputs)` returned as a two-element list, where
        `inputs[i]` is the ordered list of input names of the i-th exported submodel and `outputs[i]` its output
        names, as returned by [`export`]. When there is nothing to export, both lists are empty.

    Raises:
        ValueError: If `output_names` is provided and its length differs from the number of models to export.
    """
    # Accept plain strings as well as `Path` objects for the output directory.
    output_dir = Path(output_dir)

    if output_names is not None and len(output_names) != len(models_and_onnx_configs):
        raise ValueError(
            f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export."
        )

    outputs = []
    for i, (model_name, (submodel, sub_onnx_config)) in enumerate(models_and_onnx_configs.items()):
        output_name = output_names[i] if output_names is not None else Path(model_name + ".onnx")
        output_path = output_dir / output_name
        # `output_name` may contain sub-directories, so create the file's own parent directory.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        logger.info(
            f"\n***** Exporting submodel {i + 1}/{len(models_and_onnx_configs)}: {submodel.__class__.__name__} *****"
        )
        outputs.append(
            export(
                model=submodel,
                config=sub_onnx_config,
                output=output_path,
                opset=opset,
                device=device,
                input_shapes=input_shapes,
                disable_dynamic_axes_fix=disable_dynamic_axes_fix,
                dtype=dtype,
                no_dynamic_axes=no_dynamic_axes,
                do_constant_folding=do_constant_folding,
                model_kwargs=model_kwargs,
            )
        )

    if not outputs:
        # `zip(*[])` would yield an empty list, breaking `inputs, outputs = export_models(...)` unpacking.
        return [[], []]

    # Transpose the per-submodel `(inputs, outputs)` pairs into `[all_inputs, all_outputs]`.
    return [list(names) for names in zip(*outputs)]
def export(
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    config: OnnxConfig,
    output: Path,
    opset: Optional[int] = None,
    device: str = "cpu",
    input_shapes: Optional[Dict] = None,
    disable_dynamic_axes_fix: Optional[bool] = False,
    dtype: Optional[str] = None,
    no_dynamic_axes: bool = False,
    do_constant_folding: bool = True,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
    """
    Exports a Pytorch or TensorFlow model to an ONNX Intermediate Representation.

    Args:
        model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
            The model to export.
        config ([`~exporters.onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        output (`Path`):
            Path of the ONNX file to write. Its parent directory is created if it does not exist. A `str` path is
            also accepted.
        opset (`Optional[int]`, defaults to `None`):
            The version of the ONNX operator set to use. If `None`, `config.DEFAULT_ONNX_OPSET` is used.
        device (`Optional[str]`, defaults to `"cpu"`):
            The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
            export on CUDA devices.
        input_shapes (`Optional[Dict]`, defaults to `None`):
            If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
            Ignored (with a log message) for TensorFlow models.
        disable_dynamic_axes_fix (`Optional[bool]`, defaults to `False`):
            Whether to disable the default dynamic axes fixing.
        dtype (`Optional[str]`, defaults to `None`):
            Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
        no_dynamic_axes (bool, defaults to `False`):
            If True, disables the use of dynamic axes during ONNX export.
        do_constant_folding (bool, defaults to `True`):
            PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export. This argument should be used along the `custom_onnx_config` argument
            in case, for example, the model inputs/outputs are changed (for example, if
            `model_kwargs={"output_attentions": True}` is passed). PyTorch-only.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
        the ONNX configuration, as returned by the framework-specific exporter.

    Raises:
        ImportError: If neither PyTorch nor TensorFlow is installed, or if a diffusers model is exported without the
            `diffusers` package.
        MinimumVersionError: If the installed Transformers or PyTorch version is too old for this export.
        NotImplementedError: If `model_kwargs` is given for a TensorFlow model.
        RuntimeError: If a CUDA device is requested for a TensorFlow model, or if the model's framework is not installed.
    """
    if not (is_torch_available() or is_tf_available()):
        raise ImportError(
            "Cannot convert because neither PyTorch nor TensorFlow are installed. "
            "Please install torch or tensorflow first."
        )

    # `output` is the target ONNX file; accept `str` too and make sure its directory exists.
    output = Path(output)
    output.parent.mkdir(parents=True, exist_ok=True)

    if opset is None:
        opset = config.DEFAULT_ONNX_OPSET

    if "diffusers" in str(model.__class__) and not is_diffusers_available():
        raise ImportError("The pip package `diffusers` is required to export diffusion models to ONNX.")

    if not config.is_transformers_support_available:
        import transformers

        raise MinimumVersionError(
            f"The current version of Transformers does not allow for the export of the model. Minimum required is "
            f"{config.MIN_TRANSFORMERS_VERSION}, got: {transformers.__version__}"
        )

    if is_torch_available() and isinstance(model, nn.Module):
        from ...utils.import_utils import _torch_version

        if not is_torch_onnx_support_available():
            raise MinimumVersionError(
                f"Unsupported PyTorch version, minimum required is {TORCH_MINIMUM_VERSION}, got: {_torch_version}"
            )

        if not config.is_torch_support_available:
            raise MinimumVersionError(
                f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION}, got: {_torch_version}"
            )

        export_output = export_pytorch(
            model,
            config,
            opset,
            output,
            device=device,
            input_shapes=input_shapes,
            no_dynamic_axes=no_dynamic_axes,
            do_constant_folding=do_constant_folding,
            model_kwargs=model_kwargs,
        )
    elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
        if model_kwargs is not None:
            raise NotImplementedError(
                "The argument `model_kwargs` is used only for PyTorch ONNX export, and unavailable for the Tensorflow export."
            )
        # Catch indexed CUDA devices such as "cuda:0" as well as plain "cuda".
        if str(device).startswith("cuda"):
            raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
        if input_shapes is not None:
            logger.info("`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored.")
        export_output = export_tensorflow(model, config, opset, output)
    else:
        raise RuntimeError(
            "You either provided a PyTorch model with only TensorFlow installed, or a TensorFlow model with only PyTorch installed."
        )

    if not disable_dynamic_axes_fix:
        config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
    return export_output
def onnx_export_from_model(
model: Union["PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"],
output: Union[str, Path],
opset: Optional[int] = None,
optimize: Optional[str] = None,
monolith: bool = False,
no_post_process: bool = False,
atol: Optional[float] = None,
do_validation: bool = True,
model_kwargs: Optional[Dict[str, Any]] = None,
custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
fn_get_submodels: Optional[Callable] = None,
_variant: str = "default",
legacy: bool = False,
preprocessors: List = None,
device: str = "cpu",
no_dynamic_axes: bool = False,
task: Optional[str] = None,
use_subprocess: bool = False,
do_constant_folding: bool = True,
slim: bool = False,
**kwargs_shapes,
):
"""
Full-suite ONNX export function, exporting **from a pre-loaded PyTorch or Tensorflow model**. This function is especially useful in case one needs to do modifications on the model, as overriding a forward call, before exporting to ONNX.
Args:
> Required parameters
model (`Union["PreTrainedModel", "TFPreTrainedModel"]`):
PyTorch or TensorFlow model to export to ONNX.
output (`Union[str, Path]`):
Path indicating the directory where to store the generated ONNX model.
> Optional parameters
task (`Optional[str]`, defaults to `None`):
The task to export the model for. If not specified, the task will be auto-inferred based on the model.
opset (`Optional[int]`, defaults to `None`):
If specified, ONNX opset version to export the model with. Otherwise, the default opset for the given model architecture
will be used.
device (`str`, defaults to `"cpu"`):
The device to use to do the export. Defaults to "cpu".
optimize (`Optional[str]`, defaults to `None`):
Allows to run ONNX Runtime optimizations directly during the export. Some of these optimizations are specific to
ONNX Runtime, and the resulting ONNX will not be usable with other runtime as OpenVINO or TensorRT.
Available options: `"O1", "O2", "O3", "O4"`. Reference: [`~optimum.onnxruntime.AutoOptimizationConfig`]
monolith (`bool`, defaults to `False`):
Forces to export the model as a single ONNX file.
no_post_process (`bool`, defaults to `False`):
Allows to disable any post-processing done by default on the exported ONNX models.
atol (`Optional[float]`, defaults to `None`):
If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_configs` argument
in case, for example, the model inputs/outputs are changed (for example, if
`model_kwargs={"output_attentions": True}` is passed).
custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
fn_get_submodels (`Optional[Callable]`, defaults to `None`):
Experimental usage: Override the default submodels that are used at the export. This is
especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
use_subprocess (`bool`, defaults to `False`):
Do the ONNX exported model validation in subprocesses. This is especially useful when
exporting on CUDA device, where ORT does not release memory at inference session
destruction. When set to `True`, the `main_export` call should be guarded in
`if __name__ == "__main__":` block.
_variant (`str`, defaults to `default`):
Specify the variant of the ONNX export to use.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
no_dynamic_axes (bool, defaults to `False`):
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
slim (bool, defaults to `False`):
Use onnxslim to optimize the ONNX model.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
Example usage:
```python
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # At this point, we could override some submodules, forward methods, weights, etc. from the model.
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
TasksManager.standardize_model_attributes(model)
if hasattr(model.config, "export_model_type"):
model_type = model.config.export_model_type.replace("_", "-")
else:
model_type = model.config.model_type.replace("_", "-")
library_name = TasksManager.infer_library_from_model(model)