Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return {"input_features": {0: "batch_size", 1: "sequence_classification"}}


class MoonshineOnnxConfig(AudioToTextOnnxConfig):
    """
    ONNX config registered for the "moonshine" model type.

    The `inputs` property declares, per exported input name, which axes are dynamic
    and the symbolic name given to each of them.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

    # Opset 14 is the minimum usable one here. With opset 11 the export fails with:
    #   torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset
    #   version 11 is not supported. Support for this operator was added in version 14.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the mapping `input name -> {axis index: dynamic axis name}` for the current behavior."""
        dynamic_axes = {}

        # Every behavior except the standalone decoder takes the audio input.
        if self._behavior is not ConfigBehavior.DECODER:
            dynamic_axes["input_values"] = {0: "batch_size", 1: "num_samples"}

        # Every behavior except the standalone encoder takes decoder token ids.
        if self._behavior is not ConfigBehavior.ENCODER:
            if not self.use_past_in_inputs:
                dynamic_axes["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
            else:
                # With past key values as inputs, only the batch axis of the ids is declared dynamic.
                dynamic_axes["decoder_input_ids"] = {0: "batch_size"}
                self.add_past_key_values(dynamic_axes, direction="inputs")

        # The standalone decoder additionally receives the encoder outputs.
        if self._behavior is ConfigBehavior.DECODER:
            dynamic_axes["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

        return dynamic_axes


class WhisperOnnxConfig(AudioToTextOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

Expand All @@ -1802,11 +1829,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if is_transformers_version(">=", "4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
return common_inputs
Expand Down
81 changes: 64 additions & 17 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,51 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
return result


UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
# An ONNX-export-compatible version of `tensor.repeat_interleave`.
# Without this, we get the following error: https://github.com/pytorch/pytorch/issues/145100
# NOTE: This implementation is only necessary for export with dynamo=False (dynamo=True works correctly)
# and can be removed once Optimum switches to dynamo-based exports.
def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
    """
    Custom implementation of torch.repeat_interleave without using torch.repeat_interleave.

    The result is built from `expand`, `arange`, a comparison and boolean-mask indexing only.

    Args:
        input_tensor (torch.Tensor): The input tensor.
        repeats (int or torch.Tensor): The number of repetitions for each element. Either a Python int,
            a 0-dim tensor, or a 1-D tensor holding one count per entry along `dim`. A 1-D tensor with a
            single count is broadcast to every entry along `dim`, as `torch.repeat_interleave` does.
        dim (int, optional): The dimension along which to repeat. Defaults to None, in which case the
            input is flattened and the result is 1-D.

    Returns:
        torch.Tensor: The repeated tensor.
    """
    scalar_repeats = isinstance(repeats, int) or (torch.is_tensor(repeats) and repeats.dim() == 0)

    if dim is None:
        # No dimension given: operate on the flattened input.
        input_tensor = input_tensor.flatten()
        if scalar_repeats:
            # Uniform repetition of a flat tensor needs no mask: expand a new axis and flatten it away.
            return input_tensor.unsqueeze(1).expand(-1, repeats).flatten()
        dim = 0

    if scalar_repeats:
        # Turn the single count into one count per entry along `dim`.
        repeats = torch.full((input_tensor.shape[dim],), repeats, dtype=torch.long, device=input_tensor.device)
    elif repeats.shape[0] == 1:
        # A single-element 1-D tensor is broadcast along `dim`; without this the mask built below
        # would have one row while the expanded input has `input_tensor.shape[dim]` rows.
        repeats = repeats.expand(input_tensor.shape[dim])

    # Bring the target dimension to the front so the masking below always works on axis 0.
    if dim != 0:
        input_tensor = input_tensor.transpose(0, dim)

    # Replicate every entry `max_repeats` times, then keep only the first `repeats[i]` copies of entry i.
    max_repeats = repeats.max()
    expanded = input_tensor.unsqueeze(1).expand(-1, max_repeats, *input_tensor.shape[1:])
    mask = torch.arange(max_repeats, device=input_tensor.device) < repeats.unsqueeze(1)
    result = expanded[mask]

    # Undo the transposition so the repeated axis is back at position `dim`.
    if dim != 0:
        result = result.transpose(0, dim)

    return result


UNSUPPORTED_OPS_PATCHING_SPEC = [
PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
Comment on lines +201 to +202

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GREAT ! Thanks for omitting this !

]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]


Expand Down Expand Up @@ -239,7 +283,7 @@ def patched_forward(*args, **kwargs):
# contains the output names of the model. In the case of Timm classification models, the output
# is of type tensor. By default, it is assumed that the output names mentioned in the ONNX config
# match the outputs in order.
filterd_outputs = {}
filtered_outputs = {}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch ! I'm embarrassed by the amount of times I've modified this file without seeing this x)

if isinstance(outputs, dict):
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
Expand All @@ -248,10 +292,10 @@ def patched_forward(*args, **kwargs):
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filterd_outputs[name] = value
filtered_outputs[name] = value
elif isinstance(outputs, (list, tuple)):
outputs_list = list(config.outputs.keys())
filterd_outputs = dict(zip(outputs_list, outputs))
filtered_outputs = dict(zip(outputs_list, outputs))
else:
if len(config.outputs) > 1:
num_outputs = len(config.outputs)
Expand All @@ -261,15 +305,15 @@ def patched_forward(*args, **kwargs):
)
else:
name = list(config.outputs.keys())[0]
filterd_outputs[name] = outputs
filtered_outputs[name] = outputs
name = list(config.outputs.keys())[0]
filterd_outputs[name] = outputs
filtered_outputs[name] = outputs

if is_transformers_version(">=", "4.48"):
if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
if isinstance(filtered_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)):
filtered_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

return filterd_outputs
return filtered_outputs

self.patched_forward = patched_forward

Expand Down Expand Up @@ -325,15 +369,18 @@ def __init__(
if model.config.model_type == "pix2struct" and allow_past_in_outputs:
model.config.text_config.use_cache = True

@functools.wraps(self.orig_forward)
# Re-use the patched forward method from the parent class
self.super_patched_forward = self.patched_forward

@functools.wraps(self.super_patched_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
signature = inspect.signature(self.super_patched_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

outputs = self.orig_forward(*args, **kwargs)
outputs = self.super_patched_forward(*args, **kwargs)

# Filter out cross attention past key values output from the decoder using KV cache, as they are constants.
filterd_outputs = {}
filtered_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
Expand All @@ -346,17 +393,17 @@ def patched_forward(*args, **kwargs):
# Who cares about the encoder outputs in the decoder?
continue
else:
filterd_outputs[name] = value
filtered_outputs[name] = value
else:
if self.real_config._behavior == "monolith" or (
self.real_config._behavior == "decoder"
and (self.real_config.is_merged or not self.real_config.use_past_in_inputs)
):
filterd_outputs[name] = value
filtered_outputs[name] = value
elif self.real_config._behavior == "decoder" and self.real_config.use_past_in_inputs:
# The filtering happens here. The decoder with use_past_in_inputs=True corresponds to the autoregressive one.
filterd_outputs[name] = tuple([v[:2] for v in value])
return filterd_outputs
filtered_outputs[name] = tuple([v[:2] for v in value])
return filtered_outputs

self.patched_forward = patched_forward

Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,13 @@ class TasksManager:
"token-classification",
onnx="ModernBertOnnxConfig",
),
"moonshine": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
onnx="MoonshineOnnxConfig",
),
"mpnet": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"mobilenet-v1": "hf-internal-testing/tiny-random-MobileNetV1Model",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
"moonshine": "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "lewtun/tiny-random-mt5",
Expand Down Expand Up @@ -271,6 +272,7 @@
"mobilenet_v2": "google/mobilenet_v2_0.35_96",
"mobilevit": "apple/mobilevit-small",
"modernbert": "answerdotai/ModernBERT-base",
"moonshine": "UsefulSensors/moonshine-tiny",
"mpt": "mosaicml/mpt-7b",
"mt5": "google/mt5-small",
"musicgen": "facebook/musicgen-small",
Expand Down