@@ -184,22 +184,6 @@ def _get_submodels_for_export_diffusion(
184184 return models_for_export
185185
186186
187- def _get_submodels_for_export_decoder (
188- model : Union ["PreTrainedModel" , "TFPreTrainedModel" ],
189- use_past : bool ,
190- legacy : bool = False ,
191- ) -> Dict [str , Union ["PreTrainedModel" , "TFPreTrainedModel" ]]:
192- """
193- Returns the decoder part of the model.
194- """
195- models_for_export = {DECODER_NAME if legacy else "model" : model }
196-
197- if legacy and use_past :
198- models_for_export [DECODER_WITH_PAST_NAME ] = model
199-
200- return models_for_export
201-
202-
203187def _get_submodels_for_export_encoder_decoder (
204188 model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], use_past : bool
205189) -> Dict [str , Union ["PreTrainedModel" , "TFPreTrainedModel" ]]:
@@ -251,72 +235,6 @@ def get_encoder_decoder_models_for_export(
251235 return models_for_export
252236
253237
254- def get_decoder_models_for_export (
255- model : Union ["PreTrainedModel" , "TFPreTrainedModel" ],
256- config : "ExporterConfig" ,
257- legacy : bool = False ,
258- ) -> Dict [str , Tuple [Union ["PreTrainedModel" , "TFPreTrainedModel" ], "ExporterConfig" ]]:
259- """
260- Returns two versions of the decoder that can be used together to perform fast generation:
261-
262- 1. The first one takes regular inputs, and outputs the result along with past key/values.
263- 2. The second one takes regular inputs and past key/values, and outputs the result along with the updated past
264- key/values.
265-
266-
267- Args:
268- model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
269- The model to export.
270- config ([`~exporters.base.ExporterConfig`]):
271- The export configuration associated with the exported model.
272-
273- Returns:
274- `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], ExporterConfig]]: A Dict containing the model and
275- export configs for the encoder and decoder parts of the model.
276- """
277-
278- models_for_export = _get_submodels_for_export_decoder (model , use_past = config .use_past , legacy = legacy )
279-
280- export_kwargs = {
281- "task" : config .task ,
282- "float_dtype" : config .float_dtype ,
283- "int_dtype" : config .int_dtype ,
284- "legacy" : legacy ,
285- }
286-
287- if legacy :
288- export_config = config .__class__ (
289- model .config ,
290- use_past = config .use_past ,
291- use_past_in_inputs = False ,
292- ** export_kwargs ,
293- )
294- models_for_export [DECODER_NAME ] = (models_for_export [DECODER_NAME ], export_config )
295-
296- if config .use_past :
297- export_config_with_past = config .__class__ (
298- model .config ,
299- use_past = True ,
300- use_past_in_inputs = True ,
301- ** export_kwargs ,
302- )
303- models_for_export [DECODER_WITH_PAST_NAME ] = (
304- models_for_export [DECODER_WITH_PAST_NAME ],
305- export_config_with_past ,
306- )
307-
308- else :
309- export_config = config .__class__ (
310- model .config ,
311- use_past = config .use_past ,
312- use_past_in_inputs = config .use_past ,
313- ** export_kwargs ,
314- )
315- models_for_export ["model" ] = (models_for_export ["model" ], export_config )
316-
317- return models_for_export
318-
319-
320238def get_diffusion_models_for_export (
321239 pipeline : "DiffusionPipeline" ,
322240 int_dtype : str = "int64" ,
@@ -432,12 +350,12 @@ def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrained
432350 }
433351
434352 text_encoder_config = config .__class__ (
435- model .config , task = config .task , legacy = False , model_part = "text_encoder" , variant = config .variant
353+ model .config , task = config .task , model_part = "text_encoder" , variant = config .variant
436354 )
437355 models_for_export ["text_encoder" ] = (models_for_export ["text_encoder" ], text_encoder_config )
438356
439357 audio_encoder_config = config .__class__ (
440- model .config , task = config .task , legacy = False , model_part = "encodec_decode" , variant = config .variant
358+ model .config , task = config .task , model_part = "encodec_decode" , variant = config .variant
441359 )
442360 models_for_export ["encodec_decode" ] = (models_for_export ["encodec_decode" ], audio_encoder_config )
443361
@@ -455,7 +373,7 @@ def get_musicgen_models_for_export(model: Union["PreTrainedModel", "TFPreTrained
455373 )
456374
457375 build_delay_pattern_mask_config = config .__class__ (
458- model .config , task = config .task , legacy = False , model_part = "build_delay_pattern_mask" , variant = config .variant
376+ model .config , task = config .task , model_part = "build_delay_pattern_mask" , variant = config .variant
459377 )
460378 models_for_export ["build_delay_pattern_mask" ] = (
461379 models_for_export ["build_delay_pattern_mask" ],
@@ -482,14 +400,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel
482400 models_for_export = _get_submodels_for_export_sam (model , config .variant )
483401
484402 if config .variant == "monolith" :
485- export_config = config .__class__ (model .config , task = config .task , legacy = config . legacy )
403+ export_config = config .__class__ (model .config , task = config .task )
486404 models_for_export ["model" ] = (models_for_export ["model" ], export_config )
487405 else :
488406 vision_encoder_export_config = config .__class__ (
489- model .config , task = config .task , variant = config .variant , vision_encoder = True , legacy = config . legacy
407+ model .config , task = config .task , variant = config .variant , vision_encoder = True
490408 )
491409 prompt_encoder_mask_decoder_export_config = config .__class__ (
492- model .config , task = config .task , variant = config .variant , vision_encoder = False , legacy = config . legacy
410+ model .config , task = config .task , variant = config .variant , vision_encoder = False
493411 )
494412 models_for_export ["vision_encoder" ] = (models_for_export ["vision_encoder" ], vision_encoder_export_config )
495413 models_for_export ["prompt_encoder_mask_decoder" ] = (
@@ -547,7 +465,6 @@ def get_speecht5_models_for_export(
547465 behavior = config ._behavior , # Irrelevant here.
548466 preprocessors = config ._preprocessors ,
549467 is_postnet_and_vocoder = True ,
550- legacy = config .legacy ,
551468 )
552469 postnet_and_vocoder_export_config .variant = config .variant
553470 models_for_export ["decoder_postnet_and_vocoder" ] = (
@@ -592,7 +509,6 @@ def _get_submodels_and_export_configs(
592509 float_dtype : str = "fp32" ,
593510 fn_get_submodels : Optional [Callable ] = None ,
594511 preprocessors : Optional [List [Any ]] = None ,
595- legacy : bool = False ,
596512 model_kwargs : Optional [Dict ] = None ,
597513 exporter : str = "onnx" ,
598514):
@@ -611,7 +527,6 @@ def _get_submodels_and_export_configs(
611527 int_dtype = int_dtype ,
612528 float_dtype = float_dtype ,
613529 preprocessors = preprocessors ,
614- legacy = legacy ,
615530 )
616531
617532 export_config .variant = _variant
@@ -622,13 +537,11 @@ def _get_submodels_and_export_configs(
622537
623538 # TODO: this succession of if/else strongly suggests a refactor is needed.
624539 if (
625- model . config . is_encoder_decoder
626- and task . startswith ( TasksManager . _ENCODER_DECODER_TASKS )
540+ task . startswith ( TasksManager . _ENCODER_DECODER_TASKS )
541+ and model . config . is_encoder_decoder
627542 and not monolith
628543 ):
629544 models_and_export_configs = get_encoder_decoder_models_for_export (model , export_config )
630- elif task .startswith ("text-generation" ) and not monolith :
631- models_and_export_configs = get_decoder_models_for_export (model , export_config , legacy = legacy )
632545 elif model .config .model_type == "sam" :
633546 models_and_export_configs = get_sam_models_for_export (model , export_config )
634547 elif model .config .model_type == "speecht5" :
@@ -660,8 +573,6 @@ def _get_submodels_and_export_configs(
660573 submodels_for_export = _get_submodels_for_export_encoder_decoder (
661574 model , use_past = task .endswith ("-with-past" )
662575 )
663- elif task .startswith ("text-generation" ) and not monolith :
664- submodels_for_export = _get_submodels_for_export_decoder (model , use_past = task .endswith ("-with-past" ))
665576 else :
666577 submodels_for_export = {"model" : model }
667578
0 commit comments