|
19 | 19 | from pathlib import Path |
20 | 20 | from typing import Any, Dict, List, Optional, Tuple, Union |
21 | 21 |
|
22 | | -from openvino.runtime import PartialShape, serialize |
23 | | -from openvino.tools.ovc import convert_model |
24 | | -from openvino.runtime.utils.types import get_element_type |
25 | 22 | from transformers.utils import is_tf_available, is_torch_available |
26 | 23 |
|
| 24 | +from openvino.runtime import PartialShape, serialize |
| 25 | +from openvino.runtime.utils.types import get_element_type |
| 26 | +from openvino.tools.ovc import convert_model |
27 | 27 | from optimum.exporters.onnx.base import OnnxConfig |
28 | 28 | from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed, export_tensorflow |
29 | 29 | from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx |
30 | 30 | from optimum.utils import is_diffusers_available |
31 | 31 |
|
32 | | -from ...intel.openvino.utils import OV_XML_FILE_NAME, ONNX_WEIGHTS_NAME |
33 | | -from .utils import clear_class_registry, remove_none_from_dummy_inputs, flattenize_inputs, get_input_shapes, is_torch_model |
| 32 | +from ...intel.openvino.utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME |
| 33 | +from .utils import ( |
| 34 | + clear_class_registry, |
| 35 | + flattenize_inputs, |
| 36 | + get_input_shapes, |
| 37 | + remove_none_from_dummy_inputs, |
| 38 | +) |
34 | 39 |
|
35 | 40 |
|
36 | 41 | logger = logging.getLogger(__name__) |
@@ -197,7 +202,11 @@ def ts_patched_forward(*args, **kwargs): |
197 | 202 | except Exception: |
198 | 203 | model.config.torchscript = False |
199 | 204 | model.config.return_dict = True |
200 | | - onnx_output = output.with_suffix(".onnx") if not output.name != OV_XML_FILE_NAME else output.parent / ONNX_WEIGHTS_NAME |
| 205 | + onnx_output = ( |
| 206 | + output.with_suffix(".onnx") |
| 207 | + if not output.name != OV_XML_FILE_NAME |
| 208 | + else output.parent / ONNX_WEIGHTS_NAME |
| 209 | + ) |
201 | 210 | input_names, output_names = export_pytorch_to_onnx( |
202 | 211 | model, config, opset, onnx_output, device, input_shapes, model_kwargs |
203 | 212 | ) |
|
0 commit comments