@@ -21,7 +21,7 @@
 
 from transformers.utils import is_tf_available, is_torch_available
 
-from openvino.runtime import PartialShape, serialize
+from openvino.runtime import PartialShape, save_model
 from openvino.runtime.utils.types import get_element_type
 from openvino.tools.ovc import convert_model
 from optimum.exporters.onnx.base import OnnxConfig
@@ -151,8 +151,6 @@ def export_pytorch(
 
     with torch.no_grad():
         model.config.return_dict = True
-        custom_patcher = type(config).patch_model_for_export != OnnxConfig.patch_model_for_export
-        model.config.torchscript = not custom_patcher
         model.eval()
 
         # Check if we need to override certain configuration item
@@ -182,24 +180,30 @@ def export_pytorch(
         else:
             sig = inspect.signature(model.call)
 
-        dummy_inputs = remove_none_from_dummy_inputs(dummy_inputs)
+        dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)
         input_info = get_input_shapes(dummy_inputs, inputs)
         try:
-            if custom_patcher:
-                patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-                patched_forward = patcher.patched_forward
-
-                @functools.wraps(patched_forward)
-                def ts_patched_forward(*args, **kwargs):
-                    outputs = patched_forward(*args, **kwargs)
-                    return tuple(outputs.values())
-
-                patcher.patched_forward = ts_patched_forward
-                with patcher:
-                    ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
-            else:
+            patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
+            patched_forward = patcher.patched_forward
+
+            @functools.wraps(patched_forward)
+            def ts_patched_forward(*args, **kwargs):
+                for i in range(len(dict_inputs)):
+                    input_name = dict_inputs[i][0]
+                    keys = dict_inputs[i][1]
+                    tuple_input = kwargs[input_name]
+                    input_dict = dict(zip(keys, tuple_input))
+                    kwargs[input_name] = input_dict
+                outputs = patched_forward(*args, **kwargs)
+                return tuple(outputs.values())
+
+            patcher.patched_forward = ts_patched_forward
+            with patcher:
                 ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
         except Exception:
+            orig_torch_onnx_export = torch.onnx.export
+
+            torch.onnx.export = functools.partial(orig_torch_onnx_export, do_constant_folding=True)
             model.config.torchscript = False
             model.config.return_dict = True
             onnx_output = (
@@ -210,13 +214,19 @@ def ts_patched_forward(*args, **kwargs):
             input_names, output_names = export_pytorch_to_onnx(
                 model, config, opset, onnx_output, device, input_shapes, model_kwargs
             )
+            torch.onnx.export = orig_torch_onnx_export
             ov_model = convert_model(str(onnx_output))
-            serialize(ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output)
+            save_model(
+                ov_model,
+                output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output,
+                compress_to_fp16=False,
+            )
             return input_names, output_names, True
         clear_class_registry()
         ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
         ordered_input_names = list(inputs)
         flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
+        ov_model.validate_nodes_and_infer_types()
         for idx, out_tensor in enumerate(ov_model.outputs):
             if idx < len(output_names):
                 out_tensor.get_tensor().set_names({output_names[idx]})
@@ -233,7 +243,9 @@ def ts_patched_forward(*args, **kwargs):
             inp_tensor.get_node().set_partial_shape(static_shape)
             inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
         ov_model.validate_nodes_and_infer_types()
-        serialize(ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output)
+        save_model(
+            ov_model, output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, compress_to_fp16=False
+        )
         del model
         gc.collect()
     return input_names, output_names, False
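
For reference, a minimal standalone sketch (not part of this commit) of the API swap above, assuming OpenVINO 2023.1+ where openvino.runtime.save_model is available: unlike the older serialize, save_model compresses weights to FP16 by default, which is why the export code passes compress_to_fp16=False. The model and output paths below are purely illustrative.

from openvino.runtime import Core, save_model, serialize

core = Core()
ov_model = core.read_model("model.xml")  # illustrative IR path, not taken from the repository

# Older API, used before this commit: writes the IR exactly as converted.
serialize(ov_model, "exported/openvino_model.xml")

# Newer API: compresses weights to FP16 unless disabled, hence the explicit flag.
save_model(ov_model, "exported/openvino_model.xml", compress_to_fp16=False)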