Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def parse_args_onnx(parser):
action="store_true",
help="PyTorch-only argument. Disables PyTorch ONNX export constant folding.",
)
optional_group.add_argument(
"--slim",
action="store_true",
help="Enables onnxslim optimization.",
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -286,5 +291,6 @@ def run(self):
no_dynamic_axes=self.args.no_dynamic_axes,
model_kwargs=self.args.model_kwargs,
do_constant_folding=not self.args.no_constant_folding,
slim=self.args.slim,
**input_shapes,
)
16 changes: 16 additions & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available

from ...onnx.graph_transformations import check_and_save_model
from ...onnx.utils import _get_onnx_external_constants, _get_onnx_external_data_tensors, check_model_uses_external_data
from ...utils import (
DEFAULT_DUMMY_SHAPES,
Expand Down Expand Up @@ -917,6 +918,7 @@ def onnx_export_from_model(
task: Optional[str] = None,
use_subprocess: bool = False,
do_constant_folding: bool = True,
slim: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -972,6 +974,8 @@ def onnx_export_from_model(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
slim (bool, defaults to `False`):
Use onnxslim to optimize the ONNX model.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

Expand Down Expand Up @@ -1196,6 +1200,18 @@ def onnx_export_from_model(
optimization_config.disable_shape_inference = True
optimizer.optimize(save_dir=output, optimization_config=optimization_config, file_suffix="")

if slim:
from onnxslim import slim

onnx_models = [os.path.join(output, x) for x in os.listdir(output) if x.endswith(".onnx")]
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated

for model in onnx_models:
try:
slimmed_model = slim(model)
check_and_save_model(slimmed_model, model)
except Exception as e:
print(f"Failed to slim {model}: {e}")
Comment thread
inisis marked this conversation as resolved.
Outdated

# Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any
# TODO: treating diffusion separately is quite ugly
if not no_post_process and library_name != "diffusers":
Expand Down
47 changes: 47 additions & 0 deletions tests/exporters/onnx/test_export_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,50 @@ def test_complex_synonyms(self):
model.save_pretrained(tmpdir_in)

main_export(model_name_or_path=tmpdir_in, output=tmpdir_out, task="text-classification")

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY, library_name="transformers"))
@require_torch_gpu
@require_vision
@slow
@pytest.mark.gpu_test
@pytest.mark.run_slow
def test_exporters_cli_pytorch_with_slim(
    self,
    test_name: str,
    model_type: str,
    model_name: str,
    task: str,
    variant: str,
    monolith: bool,
    no_post_process: bool,
):
    """Run the ONNX export helper on CUDA with ``slim=True`` for one parameterized model.

    The parameters are supplied by ``parameterized.expand``; ``test_name`` is
    not read in the body. The test is skipped for the ``yolos`` and ``sam``
    model types (see the PyTorch issues linked below), and when the export
    raises a ``NotImplementedError`` whose message matches one of the two
    "unsupported" markers checked at the end. Any other ``NotImplementedError``
    is re-raised so the test fails.
    """
    # TODO: refer to https://github.com/pytorch/pytorch/issues/95377
    if model_type == "yolos":
        self.skipTest("Export on cuda device fails for yolos due to a bug in PyTorch")

    # TODO: refer to https://github.com/pytorch/pytorch/issues/107591
    if model_type == "sam":
        self.skipTest("sam export on cuda is not supported due to a bug in PyTorch")

    # speecht5 is the only model type that gets extra model kwargs (a vocoder checkpoint).
    model_kwargs = None
    if model_type == "speecht5":
        model_kwargs = {"vocoder": "fxmarty/speecht5-hifigan-tiny"}

    try:
        self._onnx_export(
            model_name,
            task,
            monolith,
            no_post_process,
            slim=True,
            device="cuda",
            variant=variant,
            model_kwargs=model_kwargs,
        )
    except NotImplementedError as e:
        message = str(e)
        if (
            "Tried to use onnxslim for the model type" in message
            or "doesn't support the graph optimization" in message
        ):
            self.skipTest(f"unsupported model type in onnxslim: {model_type}")
        else:
            # Bare raise re-raises the active exception without adding a traceback frame.
            raise
Comment thread
inisis marked this conversation as resolved.
Outdated
Loading