Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d28dee8
test
IlyasMoutawwakil Jul 10, 2025
37dac0b
fix seq2seq patched sdpa
IlyasMoutawwakil Jul 10, 2025
19b3dc2
patch qwen3_moe, out_attentions, and eager_mask
IlyasMoutawwakil Jul 10, 2025
ce3809f
fix
IlyasMoutawwakil Jul 10, 2025
bd8dec6
use optimum model
IlyasMoutawwakil Jul 10, 2025
5593686
editable subpackages
IlyasMoutawwakil Jul 11, 2025
b10c340
Apply suggestions from code review
IlyasMoutawwakil Jul 15, 2025
d3bd103
smollm3 support
IlyasMoutawwakil Jul 15, 2025
f80c1c5
Merge branch 'transformers-4.53' of https://github.com/huggingface/op…
IlyasMoutawwakil Jul 15, 2025
6b33750
deprecate tensorflow onnx export and add smollm3 to export tests
IlyasMoutawwakil Jul 17, 2025
85521dc
write a more general sdpa_mask without vmap that's also vectorized an…
IlyasMoutawwakil Jul 21, 2025
dd6d0c1
better and more generic sdpa_mask_without_vmap implementation
IlyasMoutawwakil Jul 21, 2025
3dd85c1
style and fix
IlyasMoutawwakil Jul 21, 2025
f7b6ebd
fix
IlyasMoutawwakil Jul 21, 2025
c5e0165
patch find_packed_sequence_indices as it's untraceable
IlyasMoutawwakil Jul 21, 2025
3a4ea0d
fix
IlyasMoutawwakil Jul 21, 2025
74f064d
fix
IlyasMoutawwakil Jul 21, 2025
5976851
revert tests removal until refactor
IlyasMoutawwakil Jul 22, 2025
6c602ce
fix temporary hub repo import
IlyasMoutawwakil Jul 22, 2025
b43fead
fix
IlyasMoutawwakil Jul 22, 2025
9953b91
fix external data tests on windows
IlyasMoutawwakil Jul 22, 2025
310bcd1
update phi and phi3 min version
IlyasMoutawwakil Jul 22, 2025
459316d
condition modernbert optimization test
IlyasMoutawwakil Jul 22, 2025
4fc5972
get back old (pre 4.44) bloom modeling support and remove the need fo…
IlyasMoutawwakil Jul 22, 2025
4134e46
fix test was using hardcoded architecture
IlyasMoutawwakil Jul 22, 2025
2a6ef9c
unparallelize test that uses remote code
IlyasMoutawwakil Jul 22, 2025
8623353
support older versions of mpt and phi (4.36)
IlyasMoutawwakil Jul 22, 2025
a3ad9df
remove parallelism from slow tests
IlyasMoutawwakil Jul 22, 2025
77dd30f
fix vision to text pipelines test
IlyasMoutawwakil Jul 22, 2025
d41f0ea
more specific version handling for find_packed_sequence_indices
IlyasMoutawwakil Jul 23, 2025
7790092
fix
IlyasMoutawwakil Jul 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_exporters_onnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ jobs:

- name: Test with pytest
run: |
pytest tests/exporters/onnx/test_export.py -vvvv --durations=0 -n auto
pytest tests/exporters/onnx/test_export.py -vvvv --durations=0
1 change: 0 additions & 1 deletion .github/workflows/test_exporters_tflite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ on:
branches: [main]
pull_request:
branches: [main]
types: [opened, synchronize, reopened, labeled, unlabeled]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/test_exporters_tflite_cli.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ on:
branches: [main]
pull_request:
branches: [main]
types: [opened, synchronize, reopened, labeled, unlabeled]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_onnxruntime_slow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ jobs:

- name: Test with pytest (in parallel)
run: |
pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -n auto
pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv
env:
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
RUN_SLOW: 1
1 change: 1 addition & 0 deletions optimum/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
Expand Down
2 changes: 1 addition & 1 deletion optimum/commands/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .base import ExportCommand
from .onnx import ONNXExportCommand
Expand Down
4 changes: 3 additions & 1 deletion optimum/exporters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import onnx # noqa
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from .tasks import TasksManager # noqa
from .base import ExporterConfig # noqa
7 changes: 6 additions & 1 deletion optimum/exporters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@


class ExportConfig(ABC):
pass
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
logger.warning(
"The `ExportConfig` class is deprecated and will be removed in a future version. "
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Please use `ExporterConfig` instead."
)


class ExporterConfig(ABC):
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ def main_export(
if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED and is_transformers_version("<", "4.42"):
loading_kwargs["attn_implementation"] = "eager"

# Only eager attention implementation returns attentions
if model_kwargs is not None and model_kwargs.get("output_attentions", False):
Comment thread
echarlaix marked this conversation as resolved.
loading_kwargs["attn_implementation"] = "eager"

with DisableCompileContextManager():
model = TasksManager.get_model_from_task(
task,
Expand Down
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,11 @@ def export_tensorflow(
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from
the ONNX configuration.
"""

logger.warning(
"The TensorFlow ONNX export is deprecated and will be removed in the next major release of Optimum."
)

# This is needed to import onnx and tf2onnx because onnx is also the name of the current directory.
import sys

Expand Down
57 changes: 35 additions & 22 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@
MgpstrModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Qwen3MoeModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
VisionEncoderDecoderPatcher,
VitPoseModelPatcher,
WavLMModelPatcher,
)


Expand Down Expand Up @@ -433,6 +433,11 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_tasks_manager_onnx("smollm3", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class SmolLM3OnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


@register_tasks_manager_onnx("olmo", *COMMON_TEXT_GENERATION_TASKS)
class OlmoOnnxConfig(LlamaOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
Expand All @@ -459,6 +464,7 @@ class Qwen3OnnxConfig(LlamaOnnxConfig):
)
class Qwen3MoeOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.51.0")
_MODEL_PATCHER = Qwen3MoeModelPatcher


@register_tasks_manager_onnx("gemma", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
Expand All @@ -476,19 +482,17 @@ class GraniteOnnxConfig(LlamaOnnxConfig):

@register_tasks_manager_onnx("phi", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = version.parse("4.42.0")
MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")


@register_tasks_manager_onnx("phi3", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class Phi3OnnxConfig(PhiOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
MistralDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
MIN_TRANSFORMERS_VERSION = version.parse("4.50.0")
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")


@register_tasks_manager_onnx("internlm2", *["text-generation", "text-generation-with-past"])
Expand All @@ -499,7 +503,7 @@ class InternLM2OnnxConfig(LlamaOnnxConfig):
@register_tasks_manager_onnx("mistral", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
MIN_TRANSFORMERS_VERSION = version.parse("4.35.0")

# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
Expand All @@ -511,12 +515,10 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
_MODEL_PATCHER = MistralModelPatcher


@register_tasks_manager_onnx("mpt", *["text-generation", "text-generation-with-past", "text-classification"])
@register_tasks_manager_onnx("mpt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"])
class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
DEFAULT_ONNX_OPSET = 13
# TODO: fix inference for transformers < v4.41 for beam_search > 1
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")
MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
)
Expand All @@ -525,15 +527,30 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
@register_tasks_manager_onnx("bloom", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"])
class BloomOnnxConfig(TextDecoderOnnxConfig):
# Bloom does not require position_ids input.
DUMMY_INPUT_GENERATOR_CLASSES = (
BloomDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES

DEFAULT_ONNX_OPSET = 14 # Bloom uses F.scaled_dot_product_attention
MIN_TRANSFORMERS_VERSION = version.parse("4.44.0")
MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, BloomDummyPastKeyValuesGenerator)
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if is_transformers_version(">=", "4.44"):
super().add_past_key_values(inputs_or_outputs, direction)
else:
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size * num_heads", 2: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size * num_heads", 1: decoder_sequence_name}


@register_tasks_manager_onnx(
"gpt_bigcode", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]
Expand Down Expand Up @@ -1838,11 +1855,7 @@ class UniSpeechSATOnnxConfig(HubertOnnxConfig):
],
)
class WavLMOnnxConfig(HubertOnnxConfig):
DEFAULT_ONNX_OPSET = 12
# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
# due to the op torch.nn.functional.multi_head_attention_forward used for WavLM
_MODEL_PATCHER = WavLMModelPatcher
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


@register_tasks_manager_onnx("audio-spectrogram-transformer", *["feature-extraction", "audio-classification"])
Expand Down
Loading
Loading