|
30 | 30 | from transformers.modeling_utils import get_parameter_dtype |
31 | 31 | from transformers.utils import is_tf_available, is_torch_available |
32 | 32 |
|
| 33 | +from ...onnx.graph_transformations import check_and_save_model |
33 | 34 | from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data |
34 | 35 | from ...utils import ( |
35 | 36 | DEFAULT_DUMMY_SHAPES, |
36 | 37 | ONNX_WEIGHTS_NAME, |
37 | 38 | TORCH_MINIMUM_VERSION, |
38 | 39 | is_diffusers_available, |
| 40 | + is_onnxslim_available, |
39 | 41 | is_torch_onnx_support_available, |
40 | 42 | is_transformers_version, |
41 | 43 | logging, |
@@ -917,6 +919,7 @@ def onnx_export_from_model( |
917 | 919 | task: Optional[str] = None, |
918 | 920 | use_subprocess: bool = False, |
919 | 921 | do_constant_folding: bool = True, |
| 922 | + slim: bool = False, |
920 | 923 | **kwargs_shapes, |
921 | 924 | ): |
922 | 925 | """ |
@@ -972,6 +975,8 @@ def onnx_export_from_model( |
972 | 975 | If True, disables the use of dynamic axes during ONNX export. |
973 | 976 | do_constant_folding (bool, defaults to `True`): |
974 | 977 | PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible. |
| 978 | + slim (bool, defaults to `False`): |
| 979 | + Use onnxslim to optimize the ONNX model. |
975 | 980 | **kwargs_shapes (`Dict`): |
976 | 981 | Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. |
977 | 982 |
|
@@ -1196,6 +1201,17 @@ def onnx_export_from_model( |
1196 | 1201 | optimization_config.disable_shape_inference = True |
1197 | 1202 | optimizer.optimize(save_dir=output, optimization_config=optimization_config, file_suffix="") |
1198 | 1203 |
|
| 1204 | + if slim: |
| 1205 | + if not is_onnxslim_available(): |
| 1206 | + raise ImportError("The pip package `onnxslim` is required to optimize onnx models.") |
| 1207 | + |
| 1208 | + from onnxslim import slim |
| 1209 | + |
| 1210 | + for subpath in onnx_files_subpaths: |
| 1211 | + file_path = os.path.join(output, subpath) |
| 1212 | + slimmed_model = slim(file_path) |
| 1213 | + check_and_save_model(slimmed_model, file_path) |
| 1214 | + |
1199 | 1215 | # Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any |
1200 | 1216 | # TODO: treating diffusion separately is quite ugly |
1201 | 1217 | if not no_post_process and library_name != "diffusers": |
|
0 commit comments