Skip to content

Commit d1af494

Browse files
feat: add onnxslim (#2258)
* feat: add onnxslim * fix style and rename simplify to slim * add onnxslim tests * fix format * add slim args for main_export and make slim true in test_export_cli for tests * fix format * add is_onnxslim_available func and add onnxslim to test dependency * refactor format and pin onnxslim to 0.1.53 * Update optimum/exporters/onnx/convert.py Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> * Update tests/exporters/onnx/test_export_cli.py Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> * use glob and refactor tests * Update optimum/exporters/onnx/convert.py * Update tests/exporters/onnx/test_export_cli.py * remove glob * fix tests error * add slim to _onnx_export * Update optimum/exporters/onnx/convert.py --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
1 parent 92c178d commit d1af494

8 files changed

Lines changed: 61 additions & 0 deletions

File tree

optimum/commands/export/onnx.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def parse_args_onnx(parser):
169169
action="store_true",
170170
help="PyTorch-only argument. Disables PyTorch ONNX export constant folding.",
171171
)
172+
optional_group.add_argument(
173+
"--slim",
174+
action="store_true",
175+
help="Enables onnxslim optimization.",
176+
)
172177

173178
input_group = parser.add_argument_group(
174179
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
@@ -286,5 +291,6 @@ def run(self):
286291
no_dynamic_axes=self.args.no_dynamic_axes,
287292
model_kwargs=self.args.model_kwargs,
288293
do_constant_folding=not self.args.no_constant_folding,
294+
slim=self.args.slim,
289295
**input_shapes,
290296
)

optimum/exporters/onnx/__main__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def main_export(
7878
legacy: bool = False,
7979
no_dynamic_axes: bool = False,
8080
do_constant_folding: bool = True,
81+
slim: bool = False,
8182
**kwargs_shapes,
8283
):
8384
"""
@@ -166,6 +167,8 @@ def main_export(
166167
If True, disables the use of dynamic axes during ONNX export.
167168
do_constant_folding (bool, defaults to `True`):
168169
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
170+
slim (bool, defaults to `False`):
171+
PyTorch-specific argument. If `True`, use onnxslim to optimize the ONNX model.
169172
**kwargs_shapes (`Dict`):
170173
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
171174
@@ -391,6 +394,7 @@ def main_export(
391394
task=task,
392395
use_subprocess=use_subprocess,
393396
do_constant_folding=do_constant_folding,
397+
slim=slim,
394398
**kwargs_shapes,
395399
)
396400

optimum/exporters/onnx/convert.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
from transformers.modeling_utils import get_parameter_dtype
3131
from transformers.utils import is_tf_available, is_torch_available
3232

33+
from ...onnx.graph_transformations import check_and_save_model
3334
from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data
3435
from ...utils import (
3536
DEFAULT_DUMMY_SHAPES,
3637
ONNX_WEIGHTS_NAME,
3738
TORCH_MINIMUM_VERSION,
3839
is_diffusers_available,
40+
is_onnxslim_available,
3941
is_torch_onnx_support_available,
4042
is_transformers_version,
4143
logging,
@@ -917,6 +919,7 @@ def onnx_export_from_model(
917919
task: Optional[str] = None,
918920
use_subprocess: bool = False,
919921
do_constant_folding: bool = True,
922+
slim: bool = False,
920923
**kwargs_shapes,
921924
):
922925
"""
@@ -972,6 +975,8 @@ def onnx_export_from_model(
972975
If True, disables the use of dynamic axes during ONNX export.
973976
do_constant_folding (bool, defaults to `True`):
974977
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
978+
slim (bool, defaults to `False`):
979+
Use onnxslim to optimize the ONNX model.
975980
**kwargs_shapes (`Dict`):
976981
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
977982
@@ -1196,6 +1201,17 @@ def onnx_export_from_model(
11961201
optimization_config.disable_shape_inference = True
11971202
optimizer.optimize(save_dir=output, optimization_config=optimization_config, file_suffix="")
11981203

1204+
if slim:
1205+
if not is_onnxslim_available():
1206+
raise ImportError("The pip package `onnxslim` is required to optimize onnx models.")
1207+
1208+
from onnxslim import slim
1209+
1210+
for subpath in onnx_files_subpaths:
1211+
file_path = os.path.join(output, subpath)
1212+
slimmed_model = slim(file_path)
1213+
check_and_save_model(slimmed_model, file_path)
1214+
11991215
# Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any
12001216
# TODO: treating diffusion separately is quite ugly
12011217
if not no_post_process and library_name != "diffusers":

optimum/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
is_gptqmodel_available,
4242
is_onnx_available,
4343
is_onnxruntime_available,
44+
is_onnxslim_available,
4445
is_pydantic_available,
4546
is_sentence_transformers_available,
4647
is_tf_available,

optimum/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _is_package_available(
134134
"intel-tensorflow-avx512",
135135
],
136136
)
137+
_onnxslim_available = _is_package_available("onnxslim")
137138

138139
if _tf_available and version.parse(_tf_version) < version.parse("2"):
139140
logger.warning(
@@ -267,6 +268,10 @@ def is_gptqmodel_available():
267268
)
268269

269270

271+
def is_onnxslim_available():
272+
return _onnxslim_available
273+
274+
270275
@contextmanager
271276
def check_if_pytorch_greater(target_version: str, message: str):
272277
r"""

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"sentencepiece",
3939
"rjieba",
4040
"hf_xet",
41+
"onnxslim>=0.1.53",
4142
]
4243

4344
QUALITY_REQUIRE = ["black~=23.1", "ruff==0.1.5"]

tests/exporters/onnx/test_export_cli.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS,
4141
PYTORCH_DIFFUSION_MODEL,
4242
PYTORCH_EXPORT_MODELS_TINY,
43+
PYTORCH_EXPORT_MODELS_TINY_SLIM,
4344
PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
4445
PYTORCH_TIMM_MODEL,
4546
PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES,
@@ -181,6 +182,7 @@ def _onnx_export(
181182
variant: str = "default",
182183
no_dynamic_axes: bool = False,
183184
model_kwargs: Optional[Dict] = None,
185+
slim: bool = False,
184186
):
185187
# We need to set this to some value to be able to test the outputs values for batch size > 1.
186188
if task == "text-classification":
@@ -203,6 +205,7 @@ def _onnx_export(
203205
no_dynamic_axes=no_dynamic_axes,
204206
pad_token_id=pad_token_id,
205207
model_kwargs=model_kwargs,
208+
slim=slim,
206209
)
207210
except MinimumVersionError as e:
208211
pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")
@@ -730,3 +733,24 @@ def test_complex_synonyms(self):
730733
model.save_pretrained(tmpdir_in)
731734

732735
main_export(model_name_or_path=tmpdir_in, output=tmpdir_out, task="text-classification")
736+
737+
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY_SLIM, library_name="transformers"))
738+
def test_exporters_cli_pytorch_with_slim(
739+
self,
740+
test_name: str,
741+
model_type: str,
742+
model_name: str,
743+
task: str,
744+
variant: str,
745+
monolith: bool,
746+
no_post_process: bool,
747+
):
748+
self._onnx_export(
749+
model_name,
750+
task,
751+
monolith,
752+
no_post_process,
753+
slim=True,
754+
device="cpu",
755+
variant=variant,
756+
)

tests/exporters/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,7 @@
365365
"timm/ese_vovnet19b_dw.ra_in1k": ["image-classification"],
366366
}
367367
}
368+
369+
PYTORCH_EXPORT_MODELS_TINY_SLIM = {
370+
k: v for k, v in PYTORCH_EXPORT_MODELS_TINY.items() if k in ["modernbert", "llama", "t5", "whisper"]
371+
}

0 commit comments

Comments
 (0)