| 1 | +# Copyright 2022 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import logging |
| 16 | +import os |
| 17 | +from pathlib import Path |
| 18 | +from typing import Any, Callable, Dict, Optional, Union |
| 19 | + |
| 20 | +from requests.exceptions import ConnectionError as RequestsConnectionError |
| 21 | +from transformers import AutoTokenizer |
| 22 | +from transformers.utils import is_torch_available |
| 23 | + |
| 24 | +from optimum.exporters import TasksManager |
| 25 | +from optimum.exporters.onnx import __main__ as optimum_main |
| 26 | +from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast |
| 27 | +from optimum.utils import DEFAULT_DUMMY_SHAPES |
| 28 | +from optimum.utils.save_utils import maybe_save_preprocessors |
| 29 | + |
| 30 | +from .convert import export_models |
| 31 | + |
| 32 | + |
| 33 | +OV_XML_FILE_NAME = "openvino_model.xml" |
| 34 | + |
| 35 | +logger = logging.getLogger(__name__) |
| 36 | + |
| 37 | +if is_torch_available(): |
| 38 | + import torch |
| 39 | + |
| 40 | + |
| 41 | +def main_export( |
| 42 | + model_name_or_path: str, |
| 43 | + output: Union[str, Path], |
| 44 | + task: str = "auto", |
| 45 | + device: str = "cpu", |
| 46 | + fp16: Optional[bool] = False, |
| 47 | + framework: Optional[str] = None, |
| 48 | + cache_dir: Optional[str] = None, |
| 49 | + trust_remote_code: bool = False, |
| 50 | + pad_token_id: Optional[int] = None, |
| 51 | + subfolder: str = "", |
| 52 | + revision: str = "main", |
| 53 | + force_download: bool = False, |
| 54 | + local_files_only: bool = False, |
| 55 | + use_auth_token: Optional[Union[bool, str]] = None, |
| 56 | + model_kwargs: Optional[Dict[str, Any]] = None, |
| 57 | + custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, |
| 58 | + fn_get_submodels: Optional[Callable] = None, |
| 59 | + **kwargs_shapes, |
| 60 | +): |
| 61 | + """ |
| 62 | + Full-suite OpenVINO export. |
| 63 | +
|
| 64 | + Args: |
| 65 | + > Required parameters |
| 66 | +
|
| 67 | + model_name_or_path (`str`): |
| 68 | + Model ID on huggingface.co or path on disk to the model repository to export. |
| 69 | + output (`Union[str, Path]`): |
| 70 | +            Path indicating the directory where to store the generated OpenVINO model.
| 71 | +
|
| 72 | + > Optional parameters |
| 73 | +
|
| 74 | +        task (`str`, defaults to `"auto"`):
| 75 | + The task to export the model for. If not specified, the task will be auto-inferred based on the model. For decoder models, |
| 76 | + use `xxx-with-past` to export the model using past key values in the decoder. |
| 77 | + device (`str`, defaults to `"cpu"`): |
| 78 | +            The device to use for the export.
| 79 | +        fp16 (`Optional[bool]`, defaults to `False`):
| 80 | + Use half precision during the export. PyTorch-only, requires `device="cuda"`. |
| 81 | + framework (`Optional[str]`, defaults to `None`): |
| 82 | + The framework to use for the ONNX export (`"pt"` or `"tf"`). If not provided, will attempt to automatically detect |
| 83 | + the framework for the checkpoint. |
| 84 | + cache_dir (`Optional[str]`, defaults to `None`): |
| 85 | +            Path indicating where to store the cache. If not provided, the default Hugging Face cache path will be used.
| 86 | + trust_remote_code (`bool`, defaults to `False`): |
| 87 | +            Allows using the custom modeling code hosted in the model repository. This option should only be set for repositories
| 88 | +            you trust and in which you have read the code, as it will execute arbitrary code from the model repository
| 89 | +            on your local machine.
| 90 | + pad_token_id (`Optional[int]`, defaults to `None`): |
| 91 | + This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. |
| 92 | + subfolder (`str`, defaults to `""`): |
| 93 | + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can |
| 94 | + specify the folder name here. |
| 95 | + revision (`str`, defaults to `"main"`): |
| 96 | + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. |
| 97 | + force_download (`bool`, defaults to `False`): |
| 98 | + Whether or not to force the (re-)download of the model weights and configuration files, overriding the |
| 99 | + cached versions if they exist. |
| 100 | +        local_files_only (`bool`, defaults to `False`):
| 101 | + Whether or not to only look at local files (i.e., do not try to download the model). |
| 102 | +        use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`):
| 103 | + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated |
| 104 | + when running `transformers-cli login` (stored in `~/.huggingface`). |
| 105 | + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): |
| 106 | + Experimental usage: keyword arguments to pass to the model during |
| 107 | +            the export. This argument should be used along with the `custom_onnx_configs` argument
| 108 | +            in case the model inputs/outputs are changed (for example, if
| 109 | + `model_kwargs={"output_attentions": True}` is passed). |
| 110 | + custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): |
| 111 | + Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). |
| 112 | + fn_get_submodels (`Optional[Callable]`, defaults to `None`): |
| 113 | +            Experimental usage: Override the default submodels that are used at export time. This is
| 114 | +            especially useful when exporting a custom architecture that needs the export to be split into several submodels (e.g. encoder-decoder). If unspecified with custom models, Optimum will try to use the default submodels for the given task, with no guarantee of success.
| 115 | + **kwargs_shapes (`Dict`): |
| 116 | +            Shapes to use during inference. This argument allows overriding the default shapes used during the export.
| 117 | +
|
| 118 | + Example usage: |
| 119 | + ```python |
| 120 | + >>> from optimum.exporters.openvino import main_export |
| 121 | +
|
| 122 | +    >>> main_export("gpt2", output="gpt2_ov/")
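| | +
| | +    >>> # A hypothetical second call (for illustration): override the dummy input shapes used for tracing.
| | +    >>> # The accepted keyword names are the keys of `optimum.utils.DEFAULT_DUMMY_SHAPES` (e.g. `batch_size`, `sequence_length`).
| | +    >>> main_export("gpt2", output="gpt2_ov_with_past/", task="text-generation-with-past", batch_size=1, sequence_length=16)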
| 123 | + ``` |
| 124 | + """ |
| 125 | + output = Path(output) |
| 126 | + if not output.exists(): |
| 127 | + output.mkdir(parents=True) |
| 128 | + |
| 129 | + original_task = task |
| 130 | + task = TasksManager.map_from_synonym(task) |
| 131 | + |
| 132 | + framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) |
| 133 | + |
| 134 | + # get the shapes to be used to generate dummy inputs |
| 135 | + input_shapes = {} |
| 136 | + for input_name in DEFAULT_DUMMY_SHAPES.keys(): |
| 137 | + input_shapes[input_name] = ( |
| 138 | + kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] |
| 139 | + ) |
| 140 | + |
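| | +    # Use half precision only when explicitly requested (PyTorch-only; intended for device="cuda").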
| 141 | + torch_dtype = None if fp16 is False else torch.float16 |
| 142 | + |
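| | +    # Resolve the task when "auto" is requested; inferring it from the model is only possible for models hosted on the Hub.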
| 143 | + if task == "auto": |
| 144 | + try: |
| 145 | + task = TasksManager.infer_task_from_model(model_name_or_path) |
| 146 | + except KeyError as e: |
| 147 | + raise KeyError( |
| 148 | + f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" |
| 149 | + ) |
| 150 | + except RequestsConnectionError as e: |
| 151 | + raise RequestsConnectionError( |
| 152 | + f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" |
| 153 | + ) |
| 154 | + |
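| | +    # Load the model for the resolved task with the requested framework, dtype and device.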
| 155 | + model = TasksManager.get_model_from_task( |
| 156 | + task, |
| 157 | + model_name_or_path, |
| 158 | + subfolder=subfolder, |
| 159 | + revision=revision, |
| 160 | + cache_dir=cache_dir, |
| 161 | + use_auth_token=use_auth_token, |
| 162 | + local_files_only=local_files_only, |
| 163 | + force_download=force_download, |
| 164 | + trust_remote_code=trust_remote_code, |
| 165 | + framework=framework, |
| 166 | + torch_dtype=torch_dtype, |
| 167 | + device=device, |
| 168 | + ) |
| 169 | + |
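| | +    # Stable diffusion pipelines get a dedicated multi-model export path; other models are checked against the ONNX exporter's supported architectures.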
| 170 | + custom_architecture = False |
| 171 | + is_stable_diffusion = "stable-diffusion" in task |
| 172 | + model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") |
| 173 | + |
| 174 | + if not is_stable_diffusion: |
| 175 | + if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: |
| 176 | + raise ValueError( |
| 177 | + f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " |
| 178 | + f"If you want to support {model_type} please propose a PR or open up an issue." |
| 179 | + ) |
| 180 | + if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( |
| 181 | + task, exporter="onnx" |
| 182 | + ): |
| 183 | + custom_architecture = True |
| 184 | + |
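| | +    # Exporting a custom architecture requires both an explicit `custom_onnx_configs` and an explicit task.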
| 185 | + if custom_architecture and custom_onnx_configs is None: |
| 186 | + raise ValueError( |
| 187 | + "Trying to export a model with a custom architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models." |
| 188 | + ) |
| 189 | + |
| 190 | + if custom_architecture and original_task == "auto": |
| 191 | + raise ValueError( |
| 192 | + f'Automatic task detection is not supported with custom architectures. Please specify the `task` argument. Suggestion: task="{task}" (or task="{task}-with-past" if the model is decoder-based and supports KV cache)' |
| 193 | + ) |
| 194 | + |
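| | +    # Default to the -with-past variant (reusing past key values) when the task was auto-detected and the model type supports it.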
| 195 | + if ( |
| 196 | + not custom_architecture |
| 197 | + and not is_stable_diffusion |
| 198 | + and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx") |
| 199 | + ): |
| 200 | +        if original_task == "auto": # Make -with-past the default if --task was not explicitly specified
| 201 | + task = task + "-with-past" |
| 202 | + else: |
| 203 | + logger.info( |
| 204 | + f"The task `{task}` was manually specified, and past key values will not be reused in the decoding." |
| 205 | +                f" If needed, please pass `--task {task}-with-past` to export using the past key values."
| 206 | + ) |
| 207 | + |
| 208 | + if original_task == "auto": |
| 209 | + synonyms_for_task = sorted(TasksManager.synonyms_for_task(task)) |
| 210 | + if synonyms_for_task: |
| 211 | + synonyms_for_task = ", ".join(synonyms_for_task) |
| 212 | + possible_synonyms = f" (possible synonyms are: {synonyms_for_task})" |
| 213 | + else: |
| 214 | + possible_synonyms = "" |
| 215 | +        logger.info(f"Automatically detected task: {task}{possible_synonyms}.")
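| | +    # Collect the submodels to export together with their ONNX configs (a single model, or several for e.g. encoder-decoder and stable diffusion pipelines).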
| 216 | + onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs( |
| 217 | + model=model, |
| 218 | + task=task, |
| 219 | + monolith=False, |
| 220 | + custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, |
| 221 | + custom_architecture=custom_architecture, |
| 222 | + fn_get_submodels=fn_get_submodels, |
| 223 | + _variant="default", |
| 224 | + ) |
| 225 | + |
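| | +    # For regular transformers models, make sure a pad token id is available when the exported configuration requires one, and save the config and preprocessors next to the exported model.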
| 226 | + if not is_stable_diffusion: |
| 227 | + needs_pad_token_id = ( |
| 228 | + isinstance(onnx_config, OnnxConfigWithPast) |
| 229 | + and getattr(model.config, "pad_token_id", None) is None |
| 230 | + and task in ["text-classification"] |
| 231 | + ) |
| 232 | + if needs_pad_token_id: |
| 233 | + if pad_token_id is not None: |
| 234 | + model.config.pad_token_id = pad_token_id |
| 235 | + else: |
| 236 | + try: |
| 237 | + tok = AutoTokenizer.from_pretrained(model_name_or_path) |
| 238 | + model.config.pad_token_id = tok.pad_token_id |
| 239 | + except Exception: |
| 240 | + raise ValueError( |
| 241 | +                        "Could not infer the pad token id, which is needed in this case. Please provide it with the --pad_token_id argument."
| 242 | + ) |
| 243 | + # Saving the model config and preprocessor as this is needed sometimes. |
| 244 | + model.config.save_pretrained(output) |
| 245 | + generation_config = getattr(model, "generation_config", None) |
| 246 | + if generation_config is not None: |
| 247 | + generation_config.save_pretrained(output) |
| 248 | + maybe_save_preprocessors(model_name_or_path, output) |
| 249 | + |
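| | +        # text-generation tasks target decoder-only models, so an encoder-decoder model here indicates a wrongly selected task.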
| 250 | + if model.config.is_encoder_decoder and task.startswith("text-generation"): |
| 251 | + raise ValueError( |
| 252 | + f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" |
| 253 | + f"at https://github.com/huggingface/optimum, if --task was explicitely passed, make sure you selected the right task for the model," |
| 254 | + f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`." |
| 255 | + ) |
| 256 | + |
| 257 | + files_subpaths = None |
| 258 | + else: |
| 259 | + # save the subcomponent configuration |
| 260 | + for model_name in models_and_onnx_configs: |
| 261 | + subcomponent = models_and_onnx_configs[model_name][0] |
| 262 | + if hasattr(subcomponent, "save_config"): |
| 263 | + subcomponent.save_config(output / model_name) |
| 264 | + elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): |
| 265 | + subcomponent.config.save_pretrained(output / model_name) |
| 266 | + |
| 267 | + files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs] |
| 268 | + |
| 269 | + # Saving the additional components needed to perform inference. |
| 270 | + model.scheduler.save_pretrained(output.joinpath("scheduler")) |
| 271 | + |
| 272 | + feature_extractor = getattr(model, "feature_extractor", None) |
| 273 | + if feature_extractor is not None: |
| 274 | + feature_extractor.save_pretrained(output.joinpath("feature_extractor")) |
| 275 | + |
| 276 | + tokenizer = getattr(model, "tokenizer", None) |
| 277 | + if tokenizer is not None: |
| 278 | + tokenizer.save_pretrained(output.joinpath("tokenizer")) |
| 279 | + |
| 280 | + tokenizer_2 = getattr(model, "tokenizer_2", None) |
| 281 | + if tokenizer_2 is not None: |
| 282 | + tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) |
| 283 | + |
| 284 | + model.save_config(output) |
| 285 | + |
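| | +    # Export all collected submodels to OpenVINO; output_names (set for stable diffusion pipelines) controls the per-submodel file paths.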
| 286 | + export_models( |
| 287 | + models_and_onnx_configs=models_and_onnx_configs, |
| 288 | + output_dir=output, |
| 289 | + output_names=files_subpaths, |
| 290 | + input_shapes=input_shapes, |
| 291 | + device=device, |
| 292 | + model_kwargs=model_kwargs, |
| 293 | + ) |