Skip to content

Commit e86ad91

Browse files
committed
docstrings and typehints
1 parent e8c0490 commit e86ad91

File tree

5 files changed

+142
-13
lines changed

5 files changed

+142
-13
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from openvino.runtime.utils.types import get_element_type
2626
from openvino.tools.ovc import convert_model
2727
from optimum.exporters.onnx.base import OnnxConfig
28-
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed, export_tensorflow as export_tensorflow_onnx
28+
from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
2929
from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx
30+
from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx
3031
from optimum.utils import is_diffusers_available
3132

3233
from ...intel.openvino.utils import OV_XML_FILE_NAME
@@ -119,6 +120,20 @@ def export(
119120

120121

121122
def export_tensorflow(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, opset: int, output: Path):
123+
"""
124+
Export the TensorFlow model to OpenVINO format.
125+
126+
Args:
127+
model (Union[): The model to export.
128+
config (OnnxConfig): The configuration of the model.
129+
opset (int): The ONNX opset version to use.
130+
output (Path): The path to save the model.
131+
132+
Returns:
133+
input_names: list of input names from ONNX configuration
134+
output_names: list of output names from ONNX configuration
135+
bool: True if the model was exported successfully.
136+
"""
122137
onnx_path = Path(output).with_suffix(".onnx")
123138
input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path)
124139
ov_model = convert_model(str(onnx_path))
@@ -139,6 +154,30 @@ def export_pytorch_via_onnx(
139154
input_shapes: Optional[Dict] = None,
140155
model_kwargs: Optional[Dict[str, Any]] = None,
141156
):
157+
"""
158+
Exports a PyTorch model to an OpenVINO Intermediate Representation via ONNX export.
159+
160+
Args:
161+
model ([`PreTrainedModel`]):
162+
The model to export.
163+
config ([`~exporters.onnx.config.OnnxConfig`]):
164+
The configuration associated with the exported model.
165+
opset (`int`):
166+
The version of the ONNX operator set to use.
167+
output (`Path`):
168+
Directory to store the exported model.
169+
device (`str`, defaults to `"cpu"`):
170+
The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
171+
export on CUDA devices.
172+
input_shapes (`optional[Dict]`, defaults to `None`):
173+
If specified, allows to use specific shapes for the example input provided to the exporter.
174+
model_kwargs (optional[Dict[str, Any]], defaults to `None`):
175+
Additional kwargs for model export
176+
177+
Returns:
178+
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
179+
the ONNX configuration and boolean flag - was legacy ONNX path were applied to model or not.
180+
"""
142181
import torch
143182

144183
output = Path(output)
@@ -186,10 +225,12 @@ def export_pytorch(
186225
export on CUDA devices.
187226
input_shapes (`optional[Dict]`, defaults to `None`):
188227
If specified, allows to use specific shapes for the example input provided to the exporter.
228+
model_kwargs (optional[Dict[str, Any]], defaults to `None`):
229+
Additional kwargs for model export
189230
190231
Returns:
191-
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
192-
the ONNX configuration.
232+
`Tuple[List[str], List[str], bool]`: A tuple with an ordered list of the model's inputs, and the named inputs from
233+
the ONNX configuration and boolean flag - was legacy ONNX path were applied to model or not.
193234
"""
194235
import torch
195236
from torch.utils._pytree import tree_map
@@ -299,6 +340,28 @@ def export_models(
299340
input_shapes: Optional[Dict] = None,
300341
model_kwargs: Optional[Dict[str, Any]] = None,
301342
) -> Tuple[List[List[str]], List[List[str]]]:
343+
"""
344+
Export the models to OpenVINO IR format
345+
346+
Args:
347+
models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]):
348+
output_dir (Path): output directory for saving models
349+
opset (Optional[int], optional, Default to None): ONNX export opset
350+
output_names (Optional[List[str]], optional, Defaults to None): model output names
351+
device (str, optional, Defaults to "cpu"):
352+
The device on which the model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
353+
export on CUDA devices.
354+
input_shapes (Optional[Dict], optional, Defaults to None):
355+
If specified, allows to use specific shapes for the example input provided to the exporter.
356+
model_kwargs (Optional[Dict[str, Any]], optional):
357+
Additional kwargs for model export
358+
359+
Raises:
360+
ValueError: if custom names set not equal of number of models
361+
362+
Returns:
363+
list of input_names and output_names from ONNX configuration
364+
"""
302365
outputs = []
303366

304367
if output_names is not None and len(output_names) != len(models_and_onnx_configs):

optimum/exporters/openvino/utils.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,48 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Any, Dict, List, Tuple, Union
16+
1517
from transformers.utils import is_torch_available
1618

1719
from openvino.runtime import PartialShape
18-
19-
from ...intel.utils.import_utils import is_nncf_available
20+
from optimum.utils import is_diffusers_available
2021

2122

2223
if is_torch_available():
2324
import torch
2425
import torch.nn as nn
26+
from transformers.modeling_utils import PreTrainedModel
27+
28+
if is_diffusers_available():
29+
from diffusers import ModelMixin
30+
2531

32+
def is_torch_model(model: Union["PreTrainedModel", "ModelMixin"]):
33+
"""
34+
Checks whether the model is a torch model.
2635
27-
def is_torch_model(model):
36+
Args:
37+
model (Union[PretrainedModel, ModelMixin]): The model to check.
38+
39+
Returns:
40+
bool: True if the model is a torch model.
41+
"""
2842
if not is_torch_available():
2943
return False
3044
return isinstance(model, nn.Module)
3145

3246

33-
def flattenize_inputs(inputs):
47+
def flattenize_inputs(inputs: List[Any]):
48+
"""
49+
Flatten the inputs into a list.
50+
51+
Args:
52+
inputs (List[Any]): The inputs to flatten.
53+
54+
Returns:
55+
List[Any]: The flattened inputs.
56+
"""
3457
flatten_inputs = []
3558
for input_data in inputs:
3659
if input_data is None:
@@ -42,8 +65,27 @@ def flattenize_inputs(inputs):
4265
return flatten_inputs
4366

4467

45-
def remove_none_from_dummy_inputs(dummy_inputs):
46-
def remove_none_from_list_tuple(item):
68+
def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
69+
"""
70+
Removes None values from the dictionary.
71+
72+
Args:
73+
dummy_inputs (Dict[str, Any]): Dictionary with None values.
74+
Returns:
75+
upd_dummy (Dict[str, Any]): updated dictionary with removed None values
76+
dict_dummy (List[Tuple[str, List[str]]]): list of inputs represented as dictionary provided as pair name and list of nested keys
77+
"""
78+
79+
def remove_none_from_list_tuple(item: Union[List[Any], Tuple[Any]]):
80+
"""
81+
Removes None values from a list or tuple.
82+
83+
Args:
84+
item (list or tuple): The list or tuple to remove None values from.
85+
86+
Returns:
87+
list or tuple: The list or tuple with None values removed.
88+
"""
4789
new_item = [i for i in item if i is not None]
4890
return type(item)(new_item)
4991

@@ -63,7 +105,18 @@ def remove_none_from_list_tuple(item):
63105
return upd_dummy, dict_dummy
64106

65107

66-
def get_input_shapes(dummy_inputs, inputs):
108+
def get_input_shapes(dummy_inputs: Dict[str, Any], inputs: Dict[str, Any]):
109+
"""
110+
Resolves input shapes based on dynamic axes from input config and dummy input shapes
111+
112+
Args:
113+
dummy_inputs (Dict[str, Any]): A dictionary of dummy inputs.
114+
inputs (Dict[str, Any]): A dictionary of input tensors.
115+
116+
Returns:
117+
input_info: List of input info for conversion
118+
119+
"""
67120
input_info = []
68121
for input_name, data in dummy_inputs.items():
69122
if isinstance(data, (tuple, list, dict)):
@@ -78,6 +131,9 @@ def get_input_shapes(dummy_inputs, inputs):
78131

79132

80133
def clear_class_registry():
134+
"""
135+
Removes Torchscript cached modules
136+
"""
81137
torch._C._jit_clear_class_registry()
82138
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
83139
torch.jit._state._clear_class_state()

optimum/intel/openvino/modeling_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from transformers import PretrainedConfig
2626
from transformers.file_utils import add_start_docstrings
2727

28-
from optimum.exporters.tasks import TasksManager
2928
from optimum.exporters.onnx.base import OnnxConfig
29+
from optimum.exporters.tasks import TasksManager
3030
from optimum.modeling_base import OptimizedModel
3131

3232
from ...exporters.openvino import export

optimum/intel/utils/modeling_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Tuple
1616

1717
import torch
18+
from transformers.modeling_utils import PreTrainedModel
1819

1920

2021
# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
@@ -91,7 +92,16 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
9192
return combined_attention_mask
9293

9394

94-
def patch_decoder_attention_mask(model):
95+
def patch_decoder_attention_mask(model: "PreTrainedModel"):
96+
"""
97+
Apply patch on decoder with past model forward to resolve first inference based on model architecture
98+
99+
Args:
100+
model (PretrainedModel): The model to patch.
101+
102+
Returns:
103+
model with applied patch
104+
"""
95105
if model.config.model_type == "bloom":
96106
model.transformer._prepare_attn_mask = _prepare_attn_mask
97107
elif model.config.model_type == "llama":

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"onnxruntime<1.15.0",
4343
],
4444
"openvino": ["openvino==2023.1.0.dev20230811", "onnx", "onnxruntime"],
45-
"nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"],
45+
"nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git", "transformers<4.32.0"],
4646
"ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
4747
"diffusers": ["diffusers", "invisible-watermark>=0.2.0"],
4848
"quality": QUALITY_REQUIRE,

0 commit comments

Comments
 (0)