Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d28dee8
test
IlyasMoutawwakil Jul 10, 2025
37dac0b
fix seq2seq patched sdpa
IlyasMoutawwakil Jul 10, 2025
19b3dc2
patch qwen3_moe, out_attentions, and eager_mask
IlyasMoutawwakil Jul 10, 2025
ce3809f
fix
IlyasMoutawwakil Jul 10, 2025
bd8dec6
use optimum model
IlyasMoutawwakil Jul 10, 2025
5593686
editable subpackages
IlyasMoutawwakil Jul 11, 2025
b10c340
Apply suggestions from code review
IlyasMoutawwakil Jul 15, 2025
d3bd103
smollm3 support
IlyasMoutawwakil Jul 15, 2025
f80c1c5
Merge branch 'transformers-4.53' of https://github.com/huggingface/op…
IlyasMoutawwakil Jul 15, 2025
6b33750
deprecate tensorflow onnx export and add smollm3 to export tests
IlyasMoutawwakil Jul 17, 2025
85521dc
write a more general sdpa_mask without vmap that's also vectorized an…
IlyasMoutawwakil Jul 21, 2025
dd6d0c1
better and more generic sdpa_mask_without_vmap implementation
IlyasMoutawwakil Jul 21, 2025
3dd85c1
style and fix
IlyasMoutawwakil Jul 21, 2025
f7b6ebd
fix
IlyasMoutawwakil Jul 21, 2025
c5e0165
patch find_packed_sequence_indices as it's untraceable
IlyasMoutawwakil Jul 21, 2025
3a4ea0d
fix
IlyasMoutawwakil Jul 21, 2025
74f064d
fix
IlyasMoutawwakil Jul 21, 2025
5976851
revert tests removal until refactor
IlyasMoutawwakil Jul 22, 2025
6c602ce
fix temporary hub repo import
IlyasMoutawwakil Jul 22, 2025
b43fead
fix
IlyasMoutawwakil Jul 22, 2025
9953b91
fix external data tests on windows
IlyasMoutawwakil Jul 22, 2025
310bcd1
update phi and phi3 min version
IlyasMoutawwakil Jul 22, 2025
459316d
condition modernbert optimization test
IlyasMoutawwakil Jul 22, 2025
4fc5972
get back old (pre 4.44) bloom modeling support and remove the need fo…
IlyasMoutawwakil Jul 22, 2025
4134e46
fix test was using hardcoded architecture
IlyasMoutawwakil Jul 22, 2025
2a6ef9c
unparallelize test that uses remote code
IlyasMoutawwakil Jul 22, 2025
8623353
support older versions of mpt and phi (4.36)
IlyasMoutawwakil Jul 22, 2025
a3ad9df
remove parallelism from slow tests
IlyasMoutawwakil Jul 22, 2025
77dd30f
fix vision to text pipelines test
IlyasMoutawwakil Jul 22, 2025
d41f0ea
more specific version handling for find_packed_sequence_indices
IlyasMoutawwakil Jul 23, 2025
7790092
fix
IlyasMoutawwakil Jul 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optimum/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
Expand Down
2 changes: 1 addition & 1 deletion optimum/commands/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .base import ExportCommand
from .onnx import ONNXExportCommand
Expand Down
4 changes: 3 additions & 1 deletion optimum/exporters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import onnx # noqa
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .tasks import TasksManager # noqa
from .base import ExporterConfig # noqa
7 changes: 6 additions & 1 deletion optimum/exporters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@


class ExportConfig(ABC):
    """Deprecated export configuration base class.

    Defining a subclass logs a deprecation warning that points to `ExporterConfig`.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Emitted once per subclass definition, not per instantiation.
        deprecation_message = (
            "The `ExportConfig` class is deprecated and will be removed in a future version. "
            "Please use `ExporterConfig` instead."
        )
        logger.warning(deprecation_message)


class ExporterConfig(ABC):
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ def main_export(
if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version("<", "4.42"):
loading_kwargs["attn_implementation"] = "eager"

# Only eager attention implementation returns attentions
if model_kwargs is not None and model_kwargs.get("output_attentions", False):
Comment thread
echarlaix marked this conversation as resolved.
loading_kwargs["attn_implementation"] = "eager"

with DisableCompileContextManager():
model = TasksManager.get_model_from_task(
task,
Expand Down
9 changes: 3 additions & 6 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@
MgpstrModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Qwen3MoeModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
VisionEncoderDecoderPatcher,
VitPoseModelPatcher,
WavLMModelPatcher,
)


Expand Down Expand Up @@ -459,6 +459,7 @@ class Qwen3OnnxConfig(LlamaOnnxConfig):
)
class Qwen3MoeOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.51.0")
_MODEL_PATCHER = Qwen3MoeModelPatcher


@register_tasks_manager_onnx("gemma", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
Expand Down Expand Up @@ -1838,11 +1839,7 @@ class UniSpeechSATOnnxConfig(HubertOnnxConfig):
],
)
class WavLMOnnxConfig(HubertOnnxConfig):
DEFAULT_ONNX_OPSET = 12
# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
# due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
_MODEL_PATCHER = WavLMModelPatcher
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


@register_tasks_manager_onnx("audio-spectrogram-transformer", *["feature-extraction", "audio-classification"])
Expand Down
185 changes: 142 additions & 43 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import transformers
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet

from ...utils import is_transformers_version, logging
from ...utils import is_torch_version, is_transformers_version, logging
from ._traceable_cache import TraceableCache


Expand All @@ -40,7 +40,9 @@
from transformers.cache_utils import DynamicCache, EncoderDecoderCache
from transformers.integrations.sdpa_attention import repeat_kv, sdpa_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

if is_transformers_version(">=", "4.53"):
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, _ignore_causal_mask_sdpa, prepare_padding_mask
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel
Expand Down Expand Up @@ -218,14 +220,79 @@ def onnx_compatible_linalg_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None
return original_linal_norm(x, ord=ord, dim=dim, keepdim=keepdim, dtype=dtype, out=out)


def sdpa_mask_without_vmap(
    batch_size: int,
    cache_position: torch.Tensor,
    kv_length: int,
    kv_offset: int = 0,
    attention_mask: Optional[torch.Tensor] = None,
    local_size: Optional[int] = None,
    allow_is_causal_skip: bool = True,
    allow_torch_fix: bool = True,
    **kwargs,
) -> Optional[torch.Tensor]:
    """Build a boolean SDPA attention mask using only broadcasting (no `vmap`).

    The mask is causal and optionally restricted by a sliding window or by attention chunks,
    depending on the `sliding_window` / `attention_chunk_size` attributes of the `config`
    forwarded through `kwargs`.

    Args:
        batch_size: Size the mask is expanded to along its first dimension.
        cache_position: 1D tensor of query positions; its length is the query length.
        kv_length: Number of key/value positions covered by the mask.
        kv_offset: Offset added to the key/value positions.
        attention_mask: Optional padding mask, forwarded to `prepare_padding_mask`.
        local_size: Forwarded to `_ignore_causal_mask_sdpa`.
        allow_is_causal_skip: If `True`, return `None` when `_ignore_causal_mask_sdpa` says the
            mask does not need to be materialized.
        allow_torch_fix: If `True`, apply the fully-masked-row workaround on torch < 2.5.
        **kwargs: Extra arguments; only `config` is read here. If it is absent, a plain causal
            mask is built.

    Returns:
        `None` when the mask can be skipped, otherwise a mask of shape
        `(batch_size, 1, q_length, kv_length)` where truthy entries mark attended positions.

    Raises:
        ValueError: If the config defines both `sliding_window` and `attention_chunk_size`.
    """
    q_length = cache_position.shape[0]
    # Potentially pad the 2D mask, and slice it correctly
    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

    # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
        return None

    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
    # but without data-dependent slicing (i.e. torch.compile friendly)
    kv_arange = torch.arange(kv_length, device=cache_position.device)
    kv_arange += kv_offset
    reshaped_cache_position = cache_position.view(-1, 1)

    # This is a bit hacky to know what pattern we are using, but all mask creation functions actually forward
    # the config through kwargs anyway, so it allows to rely on it.
    # `getattr` on a missing (None) config yields None, i.e. a plain causal pattern, instead of a KeyError.
    config = kwargs.get("config")
    sliding_window = getattr(config, "sliding_window", None)
    chunk_size = getattr(config, "attention_chunk_size", None)

    if sliding_window is not None and chunk_size is not None:
        raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")

    # Simplest and most efficient way to obtain a causal mask: broadcast (kv,) against (q, 1)
    causal_mask = kv_arange <= reshaped_cache_position
    # If using sliding window, add the sliding mask
    if sliding_window is not None:
        sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
        causal_mask *= sliding_mask_overlay
    # If using chunk attention, add the chunked mask
    elif chunk_size is not None:
        chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
        causal_mask *= chunked_mask_overlay

    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
    if padding_mask is not None:
        causal_mask = causal_mask * padding_mask[:, None, None, :]

    # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
    # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
    if is_torch_version("<", "2.5") and allow_torch_fix:
        # Out-of-place on purpose: without a padding mask `causal_mask` is still an expanded view, and
        # in-place writes to an expanded tensor are rejected by torch when batch_size > 1.
        causal_mask = causal_mask | torch.all(~causal_mask, dim=-1, keepdim=True)
    return causal_mask


def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
    """Build an additive float mask for eager attention from `sdpa_mask_without_vmap`.

    The boolean mask is always materialized (no causal skip, no torch workaround) and then
    converted: attended positions become `0.0`, all others the minimum finite value of `dtype`
    (taken from `kwargs["dtype"]`, `torch.float32` if absent).
    """
    # Both flags are forced below, so any caller-provided value is discarded first.
    for forced_flag in ("allow_torch_fix", "allow_is_causal_skip"):
        kwargs.pop(forced_flag, None)

    target_dtype = kwargs.get("dtype", torch.float32)
    boolean_mask = sdpa_mask_without_vmap(*args, allow_is_causal_skip=False, allow_torch_fix=False, **kwargs)

    zero = torch.tensor(0.0, device=boolean_mask.device, dtype=target_dtype)
    return torch.where(boolean_mask, zero, torch.finfo(target_dtype).min)


UNSUPPORTED_OPS_PATCHING_SPEC = [
PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
PatchingSpec(torch.linalg, "norm", onnx_compatible_linalg_norm, original_linal_norm),
PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]


class ModelPatcher:
Expand All @@ -239,7 +306,6 @@ def __init__(

patching_specs = config.PATCHING_SPECS or []
patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)
patching_specs.extend(CACHE_PATCHING_SPEC)

self._patching_specs = []
for spec in patching_specs:
Expand Down Expand Up @@ -355,10 +421,25 @@ def __enter__(self):
self.patch_ops()
setattr(self._model, self.orig_forward_name, self.patched_forward)

self.original_cache_class = transformers.cache_utils.Cache
transformers.cache_utils.Cache = TraceableCache

if is_transformers_version(">=", "4.53"):
self.original_sdpa_mask = ALL_MASK_ATTENTION_FUNCTIONS["sdpa"]
self.original_eager_mask = ALL_MASK_ATTENTION_FUNCTIONS["eager"]
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", sdpa_mask_without_vmap)
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", eager_mask_without_vmap)

def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
setattr(self._model, self.orig_forward_name, self.orig_forward)

transformers.cache_utils.Cache = self.original_cache_class

if is_transformers_version(">=", "4.53"):
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa", self.original_sdpa_mask)
ALL_MASK_ATTENTION_FUNCTIONS.register("eager", self.original_eager_mask)

def __call__(self, *args, **kwargs):
if getattr(self._model, self.orig_forward_name) is self.orig_forward:
logger.warning("Running the non-patched model")
Expand All @@ -368,14 +449,14 @@ def __call__(self, *args, **kwargs):
class Seq2SeqModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version(">=", "4.48"):
if is_transformers_version(">=", "4.48") and is_transformers_version("<", "4.53"):
# this is required when gpt2 is used as decoder in any
# encoder-decoder model with cross attention blocks
ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if is_transformers_version(">=", "4.48"):
if is_transformers_version(">=", "4.48") and is_transformers_version("<", "4.53"):
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward

def __init__(
Expand Down Expand Up @@ -680,43 +761,6 @@ def __init__(
self.build_alibi_tensor_original = transformers.models.falcon.modeling_falcon.build_alibi_tensor


class WavLMModelPatcher(ModelPatcher):
    """Patcher whose forward forces `output_attentions=True` and keeps only the outputs the ONNX config expects."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past

        def is_exported_output(name: str) -> bool:
            # An output is kept when its ONNX name is declared by the config, when it is a past
            # key/value and those are allowed, or when some declared output starts with its ONNX name.
            onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
            if onnx_output_name in config.outputs:
                return True
            if allow_past_in_outputs and name.startswith("past_key_values"):
                return True
            return any(key.startswith(onnx_output_name) for key in config.outputs.keys())

        @functools.wraps(self.orig_forward)
        def patched_forward(*args, **kwargs):
            # setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention
            # in https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/wavlm/modeling_wavlm.py#L496
            # that calls https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/functional.py#L5334
            forced_kwargs = self.model_kwargs
            forced_kwargs["output_attentions"] = True

            forward_signature = inspect.signature(self.orig_forward)
            args, kwargs = override_arguments(args, kwargs, forward_signature, model_kwargs=forced_kwargs)

            raw_outputs = self.orig_forward(*args, **kwargs)
            return {name: value for name, value in raw_outputs.items() if is_exported_output(name)}

        self.patched_forward = patched_forward


class MgpstrModelPatcher(ModelPatcher):
def __init__(
self,
Expand Down Expand Up @@ -1381,3 +1425,58 @@ def __init__(
model_kwargs["dataset_index"] = torch.tensor(0, device=model.device)

super().__init__(config, model, model_kwargs)


def qwen3_moe_forward_patched(self, hidden_states: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """Sparse MoE forward that visits every expert, so the loop does not depend on the routing result.

    Args:
        hidden_states: Tensor of shape `(batch_size, sequence_length, hidden_dim)`.

    Returns:
        A tuple `(final_hidden_states, router_logits)` where `final_hidden_states` has the same
        shape as the input and `router_logits` has shape `(batch_size * sequence_length, n_experts)`.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
    if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )

    # One hot encode the selected experts to create an expert mask
    # this will be used to easily index which expert is going to be solicited.
    # expert_mask: (n_experts, top_k, batch * sequence_length)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    # We loop over all possible experts (not only the ones that were hit) to avoid issues in graph execution.
    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        # `expert_idx` is a plain int, so the slice is already 2D (top_k, tokens). It must not be
        # squeezed: with top_k == 1 that would make it 1D and `torch.where` would return one tensor.
        idx, top_x = torch.where(expert_mask[expert_idx])

        # Index the correct hidden states and compute the expert hidden state for
        # the current expert. We need to make sure to multiply the output hidden
        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

        # However `index_add_` only support torch tensors for indexing so we'll use
        # the `top_x` tensor here.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states, router_logits


class Qwen3MoeModelPatcher(DecoderModelPatcher):
    """Decoder patcher that swaps `Qwen3MoeSparseMoeBlock.forward` for `qwen3_moe_forward_patched` while active.

    The swap is done on the class itself and only for transformers >= 4.53.
    """

    def __enter__(self):
        super().__enter__()

        if not is_transformers_version(">=", "4.53"):
            return
        # Remember the original forward so that `__exit__` can put it back.
        self.original_moe_forward = Qwen3MoeSparseMoeBlock.forward
        Qwen3MoeSparseMoeBlock.forward = qwen3_moe_forward_patched

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)

        if not is_transformers_version(">=", "4.53"):
            return
        Qwen3MoeSparseMoeBlock.forward = self.original_moe_forward
16 changes: 8 additions & 8 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
)

if TYPE_CHECKING:
from ..base import ExportConfig
from ..base import ExporterConfig

if is_torch_available():
from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -230,34 +230,34 @@ def get_diffusion_models_for_export(
pipeline: "DiffusionPipeline",
int_dtype: str = "int64",
float_dtype: str = "fp32",
) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "ExportConfig"]]:
) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "ExporterConfig"]]:
logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="diffusion"))
return _get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter="onnx")


def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"):
def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExporterConfig"):
logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="sam"))
return _get_sam_models_for_export(model, config)


def get_speecht5_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig", model_kwargs: Optional[Dict]
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExporterConfig", model_kwargs: Optional[Dict]
):
logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="speecht5"))
return _get_speecht5_models_for_export(model, config)


def get_encoder_decoder_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "ExportConfig"]]:
model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExporterConfig"
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "ExporterConfig"]]:
logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="encoder-decoder"))
return _get_encoder_decoder_models_for_export(model, config)


def get_decoder_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
config: "ExportConfig",
config: "ExporterConfig",
legacy: bool = False,
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "ExportConfig"]]:
) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "ExporterConfig"]]:
logger.warning(DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT.format(model_type="decoder"))
return _get_decoder_models_for_export(model, config, legacy)
Loading
Loading