
Commit 4b8ed24

eaidova and AlexKoff88 authored
OV migrate model export on pytorch frontend (#397)
* switch on pytorch frontend
* fixes for seq2seq
* wip
* cleanup
* fix style
* revert changes not related to pr
* clear ts registry
* remove ov dev from deps
* update tests
* return serialize back
* switch on pytorch frontend
* fixes for seq2seq
* wip
* cleanup
* fix style
* revert changes not related to pr
* clear ts registry
* remove ov dev from deps
* return serialize back
* Added weights compression
* Changed NNCF version to develop
* resolve dictionary as input
* fix llama export in quantization flow
* rebase with fixes
* update prerelease package
* fix onnx name issues
* experiments with tests
* better workaround for nncf patch torch ops and apply review comments
* remove flag from_onnx
* refactoring
* docstrings and typehints
* small fixes
* add docstring to main_export
* fix timm models
* fix circular imports
* update ov version
* revert excluding deberta
* update nncf on package

---------

Co-authored-by: Alexander <[email protected]>

1 parent 681b946 · commit 4b8ed24
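The first bullet, "switch on pytorch frontend", is the core of the change: models are handed to OpenVINO's converter as live `torch.nn.Module` objects instead of being exported to an intermediate ONNX file first (the ONNX path stays available as `export_pytorch_via_onnx` in the files below). The following sketch only illustrates that conversion style; it is not code from this commit, and it assumes an OpenVINO 2023.x installation with the `convert_model`/`serialize` entry points, with a toy module and placeholder shapes:

```python
# Illustrative sketch only (not part of this diff): convert a torch.nn.Module
# directly through OpenVINO's PyTorch frontend, then serialize the IR to disk.
# Assumes an OpenVINO 2023.x install; the module and input shape are placeholders.
import torch
from openvino.tools.mo import convert_model
from openvino.runtime import serialize


class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) * 2.0


# The PyTorch frontend traces the module from an example input; no ONNX file is produced.
ov_model = convert_model(TinyNet(), example_input=torch.zeros(1, 8))
serialize(ov_model, "tiny_net.xml")  # writes tiny_net.xml and tiny_net.bin
```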

File tree

16 files changed: +1026 −108 lines changed

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx


__all__ = ["main_export", "export", "export_models"]
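The five lines above define the public surface of the new `optimum.exporters.openvino` package. For orientation, a minimal usage sketch of that surface follows; the model ID, output directory, and chosen keyword values are illustrative placeholders and not part of this commit, though every parameter shown exists in the `main_export` signature added below:

```python
# Minimal usage sketch (assumed): export a Hub checkpoint to OpenVINO IR through
# the package's main entry point. "gpt2" and "gpt2_ov/" are placeholder names.
from optimum.exporters.openvino import main_export

main_export(
    "gpt2",                    # model ID on huggingface.co, or a local path
    output="gpt2_ov/",         # directory that will receive openvino_model.xml / .bin
    task="auto",               # default: infer the task (and add -with-past when supported)
    fp16=False,                # half precision is PyTorch-only and requires device="cuda"
    trust_remote_code=False,   # only enable for repositories whose code you have reviewed
)
```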
Lines changed: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers.utils import is_torch_available

from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_save_preprocessors

from .convert import export_models


OV_XML_FILE_NAME = "openvino_model.xml"

logger = logging.getLogger(__name__)

if is_torch_available():
    import torch


def main_export(
    model_name_or_path: str,
    output: Union[str, Path],
    task: str = "auto",
    device: str = "cpu",
    fp16: Optional[bool] = False,
    framework: Optional[str] = None,
    cache_dir: Optional[str] = None,
    trust_remote_code: bool = False,
    pad_token_id: Optional[int] = None,
    subfolder: str = "",
    revision: str = "main",
    force_download: bool = False,
    local_files_only: bool = False,
    use_auth_token: Optional[Union[bool, str]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
    custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None,
    fn_get_submodels: Optional[Callable] = None,
    **kwargs_shapes,
):
    """
    Full-suite OpenVINO export.

    Args:
        > Required parameters

        model_name_or_path (`str`):
            Model ID on huggingface.co or path on disk to the model repository to export.
        output (`Union[str, Path]`):
            Path indicating the directory where to store the generated OpenVINO model.

        > Optional parameters

        task (`Optional[str]`, defaults to `"auto"`):
            The task to export the model for. If not specified, the task will be auto-inferred based on the model. For decoder models,
            use `xxx-with-past` to export the model using past key values in the decoder.
        device (`str`, defaults to `"cpu"`):
            The device to use to do the export. Defaults to "cpu".
        fp16 (`Optional[bool]`, defaults to `False`):
            Use half precision during the export. PyTorch-only, requires `device="cuda"`.
        framework (`Optional[str]`, defaults to `None`):
            The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect
            the framework for the checkpoint.
        cache_dir (`Optional[str]`, defaults to `None`):
            Path indicating where to store cache. The default Hugging Face cache path will be used by default.
        trust_remote_code (`bool`, defaults to `False`):
            Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories
            you trust and in which you have read the code, as it will execute arbitrary code present in the model repository on your
            local machine.
        pad_token_id (`Optional[int]`, defaults to `None`):
            This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
        subfolder (`str`, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
            specify the folder name here.
        revision (`str`, defaults to `"main"`):
            Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
        force_download (`bool`, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        local_files_only (`Optional[bool]`, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        use_auth_token (`Optional[str]`, defaults to `None`):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
            Experimental usage: keyword arguments to pass to the model during
            the export. This argument should be used along the `custom_onnx_configs` argument
            in case, for example, the model inputs/outputs are changed (for example, if
            `model_kwargs={"output_attentions": True}` is passed).
        custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`):
            Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced
            users that desire a finer-grained control on the export. An example is available
            [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model).
        fn_get_submodels (`Optional[Callable]`, defaults to `None`):
            Experimental usage: Override the default submodels that are used at the export. This is
            especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified
            with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
        **kwargs_shapes (`Dict`):
            Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

    Example usage:
    ```python
    >>> from optimum.exporters.openvino import main_export

    >>> main_export("gpt2", output="gpt2_onnx/")
    ```
    """
    output = Path(output)
    if not output.exists():
        output.mkdir(parents=True)

    original_task = task
    task = TasksManager.map_from_synonym(task)

    framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)

    # get the shapes to be used to generate dummy inputs
    input_shapes = {}
    for input_name in DEFAULT_DUMMY_SHAPES.keys():
        input_shapes[input_name] = (
            kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
        )

    torch_dtype = None if fp16 is False else torch.float16

    if task == "auto":
        try:
            task = TasksManager.infer_task_from_model(model_name_or_path)
        except KeyError as e:
            raise KeyError(
                f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
            )
        except RequestsConnectionError as e:
            raise RequestsConnectionError(
                f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
            )

    model = TasksManager.get_model_from_task(
        task,
        model_name_or_path,
        subfolder=subfolder,
        revision=revision,
        cache_dir=cache_dir,
        use_auth_token=use_auth_token,
        local_files_only=local_files_only,
        force_download=force_download,
        trust_remote_code=trust_remote_code,
        framework=framework,
        torch_dtype=torch_dtype,
        device=device,
    )

    custom_architecture = False
    is_stable_diffusion = "stable-diffusion" in task
    model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

    if not is_stable_diffusion:
        if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
            raise ValueError(
                f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. "
                f"If you want to support {model_type} please propose a PR or open up an issue."
            )
        if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
            task, exporter="onnx"
        ):
            custom_architecture = True

    if custom_architecture and custom_onnx_configs is None:
        raise ValueError(
            "Trying to export a model with a custom architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models."
        )

    if custom_architecture and original_task == "auto":
        raise ValueError(
            f'Automatic task detection is not supported with custom architectures. Please specify the `task` argument. Suggestion: task="{task}" (or task="{task}-with-past" if the model is decoder-based and supports KV cache)'
        )

    if (
        not custom_architecture
        and not is_stable_diffusion
        and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
    ):
        if original_task == "auto":  # Make -with-past the default if --task was not explicitly specified
            task = task + "-with-past"
        else:
            logger.info(
                f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
                f" If needed, please pass `--task {task}-with-past` to export using the past key values."
            )

    if original_task == "auto":
        synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
        if synonyms_for_task:
            synonyms_for_task = ", ".join(synonyms_for_task)
            possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
        else:
            possible_synonyms = ""
        logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
    onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs(
        model=model,
        task=task,
        monolith=False,
        custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {},
        custom_architecture=custom_architecture,
        fn_get_submodels=fn_get_submodels,
        _variant="default",
    )

    if not is_stable_diffusion:
        needs_pad_token_id = (
            isinstance(onnx_config, OnnxConfigWithPast)
            and getattr(model.config, "pad_token_id", None) is None
            and task in ["text-classification"]
        )
        if needs_pad_token_id:
            if pad_token_id is not None:
                model.config.pad_token_id = pad_token_id
            else:
                try:
                    tok = AutoTokenizer.from_pretrained(model_name_or_path)
                    model.config.pad_token_id = tok.pad_token_id
                except Exception:
                    raise ValueError(
                        "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                    )
        # Saving the model config and preprocessor as this is needed sometimes.
        model.config.save_pretrained(output)
        generation_config = getattr(model, "generation_config", None)
        if generation_config is not None:
            generation_config.save_pretrained(output)
        maybe_save_preprocessors(model_name_or_path, output)

        if model.config.is_encoder_decoder and task.startswith("text-generation"):
            raise ValueError(
                f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report "
                f"at https://github.com/huggingface/optimum, if --task was explicitly passed, make sure you selected the right task for the model,"
                f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
            )

        files_subpaths = None
    else:
        # save the subcomponent configuration
        for model_name in models_and_onnx_configs:
            subcomponent = models_and_onnx_configs[model_name][0]
            if hasattr(subcomponent, "save_config"):
                subcomponent.save_config(output / model_name)
            elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
                subcomponent.config.save_pretrained(output / model_name)

        files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs]

        # Saving the additional components needed to perform inference.
        model.scheduler.save_pretrained(output.joinpath("scheduler"))

        feature_extractor = getattr(model, "feature_extractor", None)
        if feature_extractor is not None:
            feature_extractor.save_pretrained(output.joinpath("feature_extractor"))

        tokenizer = getattr(model, "tokenizer", None)
        if tokenizer is not None:
            tokenizer.save_pretrained(output.joinpath("tokenizer"))

        tokenizer_2 = getattr(model, "tokenizer_2", None)
        if tokenizer_2 is not None:
            tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

        model.save_config(output)

    export_models(
        models_and_onnx_configs=models_and_onnx_configs,
        output_dir=output,
        output_names=files_subpaths,
        input_shapes=input_shapes,
        device=device,
        model_kwargs=model_kwargs,
    )

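End to end, the directory produced by `main_export` can then be consumed by the OpenVINO model classes elsewhere in this repository. The sketch below is a hedged illustration, not part of this diff: it assumes `OVModelForCausalLM` is importable from `optimum.intel`, and the model ID, output directory, task name, and prompt are placeholders.

```python
# Hedged end-to-end sketch (not part of this diff): export with main_export, then
# reload the produced openvino_model.xml for generation.
from transformers import AutoTokenizer

from optimum.exporters.openvino import main_export
from optimum.intel import OVModelForCausalLM  # assumed available, as elsewhere in optimum-intel

main_export("gpt2", output="gpt2_ov/", task="text-generation-with-past")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = OVModelForCausalLM.from_pretrained("gpt2_ov/")  # reads the exported OpenVINO IR

inputs = tokenizer("The OpenVINO export now runs on the PyTorch frontend", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```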