Merged
44 commits
63a6efe
test left-padded batched inference
IlyasMoutawwakil Jul 24, 2025
39496d8
demonstrate batched text generation failure
IlyasMoutawwakil Jul 24, 2025
2ccc150
fix remote code
IlyasMoutawwakil Jul 24, 2025
ecf65d5
fix
IlyasMoutawwakil Jul 24, 2025
9f3eedc
fix position_ids generation inside ORTModelForCausalLM class
IlyasMoutawwakil Jul 24, 2025
b7bec5e
it works until transformers 4.52 -_-
IlyasMoutawwakil Jul 25, 2025
0df42e5
now run with latest transformers
IlyasMoutawwakil Jul 25, 2025
999a145
boolean 4D mask is actually not supported by torch onnx exporter
IlyasMoutawwakil Jul 25, 2025
638856e
only test generation with batched inputs, for logits are a bit off be…
IlyasMoutawwakil Jul 25, 2025
3d40502
boolean mask safe softmax batched inference
IlyasMoutawwakil Jul 25, 2025
023d2ac
style
IlyasMoutawwakil Jul 25, 2025
accf852
use old typing
IlyasMoutawwakil Jul 25, 2025
0965ea9
don't do unnecessary patching
IlyasMoutawwakil Jul 25, 2025
d1f9bbd
try to avoid spamming the hub for an image
IlyasMoutawwakil Jul 25, 2025
01c4084
update min transformers version
IlyasMoutawwakil Jul 27, 2025
aeeecb2
better and direct torch patching
IlyasMoutawwakil Jul 28, 2025
fc62f42
more batched generation special cases
IlyasMoutawwakil Jul 28, 2025
ba994fb
style
IlyasMoutawwakil Jul 28, 2025
de6a798
initialize the pil image instead of downloading it
IlyasMoutawwakil Jul 28, 2025
cf164b3
use random pil image
IlyasMoutawwakil Jul 28, 2025
5934bf9
test different versions of transformers in fast tests
IlyasMoutawwakil Jul 28, 2025
4b76f5e
fix
IlyasMoutawwakil Jul 28, 2025
e171196
revert diffusers changes for now
IlyasMoutawwakil Jul 28, 2025
5ab88b6
mask padding kv cache as well
IlyasMoutawwakil Jul 28, 2025
4d35600
fix masking for old bloom
IlyasMoutawwakil Jul 28, 2025
b2a5f41
use constant image to avoid image loading errors
IlyasMoutawwakil Jul 28, 2025
3f58892
style
IlyasMoutawwakil Jul 28, 2025
b9d2e03
test diffusers in series to avoid runner dying
IlyasMoutawwakil Jul 28, 2025
bdcc425
fix
IlyasMoutawwakil Jul 28, 2025
a3dc4e8
cleanup and some comments
IlyasMoutawwakil Jul 28, 2025
a1ff2f2
fix and test falcon alibi
IlyasMoutawwakil Jul 28, 2025
603f62c
style
IlyasMoutawwakil Jul 28, 2025
cf5b562
fix, support and test multi_query=False as well
IlyasMoutawwakil Jul 28, 2025
3a29549
only apply masked testing for transformers version previous to 4.39
IlyasMoutawwakil Jul 28, 2025
af5fa34
Update optimum/onnxruntime/modeling_decoder.py
IlyasMoutawwakil Jul 28, 2025
59c0c14
use text decoder position ids onnx config but test its sync with list
IlyasMoutawwakil Jul 29, 2025
b5d92e5
Merge branch 'fix-ort-batched-generation' of https://github.com/huggi…
IlyasMoutawwakil Jul 29, 2025
9db07bf
fix opt
IlyasMoutawwakil Jul 29, 2025
98123d4
style
IlyasMoutawwakil Jul 29, 2025
411df8f
fix sdpa without overriding torch onnx exporter
IlyasMoutawwakil Jul 29, 2025
133f340
use inplace op ;-;
IlyasMoutawwakil Jul 29, 2025
9044948
Merge branch 'main' into fix-ort-batched-generation
IlyasMoutawwakil Jul 30, 2025
c98ab28
fix st test
IlyasMoutawwakil Jul 30, 2025
e787b92
patch directly in onnx because patch needs to happen after softmax
IlyasMoutawwakil Jul 30, 2025
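Several of the commits above concern left-padded batched inference and `position_ids` generation. As an illustration only (this is not the code from the PR), the sketch below shows the common way to derive position ids from a 0/1 attention mask so that left padding does not shift the positions of real tokens. The function name `position_ids_from_mask` and the dummy value 1 for padded slots are assumptions made for this example.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask):
    """Derive per-token position ids from a 0/1 attention mask (one row per sequence)."""
    position_ids = []
    for row in attention_mask:
        # Running count of real tokens, shifted so the first real token gets position 0.
        cumulative = [count - 1 for count in accumulate(row)]
        # Padded slots get a dummy position (1 here, an assumption); they are masked out anyway.
        position_ids.append([pos if keep else 1 for pos, keep in zip(cumulative, row)])
    return position_ids


batch_mask = [
    [0, 0, 1, 1, 1],  # left-padded sequence with 3 real tokens
    [1, 1, 1, 1, 1],  # sequence without padding
]
batch_positions = position_ids_from_mask(batch_mask)
```

With this scheme, the first real token of every sequence is at position 0 regardless of how much padding precedes it.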
22 changes: 18 additions & 4 deletions .github/workflows/test_onnxruntime.yml
@@ -27,6 +27,7 @@ jobs:
matrix:
python-version: [3.9]
runs-on: [ubuntu-22.04]
transformers_version: [latest, 4.36.*, 4.45.*]
test_file:
[
test_timm.py,
@@ -59,13 +60,26 @@ jobs:
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[tests,onnxruntime] diffusers

- name: Test with pytest (in series)
if: matrix.test_file == 'test_modeling.py'
- name: Install transformers ${{ matrix.transformers_version }}
run: |
pytest tests/onnxruntime/test_modeling.py -m "run_in_series" --durations=0 -vvvv
if [ "${{ matrix.transformers_version }}" == '4.36.*' ]; then
pip install "transformers==4.36.*" "diffusers<0.32.0"
elif [ "${{ matrix.transformers_version }}" == '4.45.*' ]; then
pip install "transformers==4.45.*" "diffusers<0.33.0"
else
pip install transformers;
fi

- name: Test with pytest (in parallel)
if: matrix.test_file != 'test_diffusion.py'
run: |
pytest tests/onnxruntime/${{ matrix.test_file }} --durations=0 -vvvv -n auto
env:
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}

- name: Test with pytest (in series)
if: matrix.test_file == 'test_diffusion.py'
run: |
pytest tests/onnxruntime/${{ matrix.test_file }} -m "not run_in_series" --durations=0 -vvvv -n auto
pytest tests/onnxruntime/${{ matrix.test_file }} --durations=0 -vvvv
env:
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
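The install step in this workflow pairs each pinned `transformers` version with an upper bound on `diffusers`. The same mapping can be sketched in Python as below; the function name `pip_requirements` is hypothetical and exists only to make the branching explicit, while the version strings are copied from the workflow.

```python
def pip_requirements(transformers_version):
    """Return the pip requirement strings for one value of the CI matrix entry."""
    if transformers_version == "4.36.*":
        return ["transformers==4.36.*", "diffusers<0.32.0"]
    if transformers_version == "4.45.*":
        return ["transformers==4.45.*", "diffusers<0.33.0"]
    # Any other value (here "latest") installs transformers without a pin.
    return ["transformers"]


latest_requirements = pip_requirements("latest")
legacy_requirements = pip_requirements("4.36.*")
```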
17 changes: 3 additions & 14 deletions .github/workflows/test_onnxruntime_slow.yml
@@ -36,16 +36,15 @@ jobs:
python-version: [3.9]
transformers-version: [latest]
runs-on: [ubuntu-22.04, windows-2022]
include:
- {python-version: 3.9, transformers-version: 4.36.*, runs-on: ubuntu-22.04}
- {python-version: 3.9, transformers-version: 4.45.*, runs-on: ubuntu-22.04}

runs-on: ${{ matrix.runs-on }}

steps:
- name: Free Disk Space (Ubuntu)
if: matrix.runs-on == 'ubuntu-22.04'
uses: jlumbroso/free-disk-space@main
with:
swap-storage: false

- name: Free Disk Space (macOS)
if: matrix.runs-on == 'macos-15'
@@ -69,22 +68,12 @@ jobs:
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[tests,onnxruntime] diffusers

- name: Install transformers ${{ matrix.transformers-version }}
if: ${{ matrix.transformers-version == '4.36.*' }}
run: |
pip install "transformers==${{ matrix.transformers-version }}" "diffusers<0.32.0"

- name: Install transformers ${{ matrix.transformers-version }}
if: ${{ matrix.transformers-version == '4.45.*' }}
run: |
pip install "transformers==${{ matrix.transformers-version }}" "diffusers<0.33.0"

- name: Test with pytest (in series)
run: |
pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv
env:
RUN_SLOW: 1

- name: Test with pytest (in parallel)
run: |
pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/config.py
@@ -94,13 +94,14 @@ def __init__(
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"}
self.add_past_key_values(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
Member Author

this was wrong

else:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
}

return common_inputs

@property
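The `attention_mask` axis in this hunk changes from `past_sequence_length + 1` to `past_sequence_length + sequence_length`. A minimal sketch of why the old name only holds for single-token decoding is given below; the helper `decoder_input_shapes` is hypothetical and only computes the concrete shapes the dynamic axes stand for.

```python
def decoder_input_shapes(batch_size, past_sequence_length, sequence_length):
    """Concrete shapes for a decoder step that receives a KV cache plus new tokens."""
    return {
        "input_ids": (batch_size, sequence_length),
        # The mask must cover every cached position and every new token.
        "attention_mask": (batch_size, past_sequence_length + sequence_length),
    }


# Single-token decoding: "past + 1" and "past + sequence_length" agree.
single_step = decoder_input_shapes(batch_size=2, past_sequence_length=4, sequence_length=1)
# Several new tokens on top of a cache: only "past + sequence_length" is right.
multi_step = decoder_input_shapes(batch_size=2, past_sequence_length=4, sequence_length=3)
```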
74 changes: 20 additions & 54 deletions optimum/exporters/onnx/model_configs.py
@@ -92,9 +92,7 @@
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MgpstrModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Qwen3MoeModelPatcher,
SAMModelPatcher,
@@ -409,20 +407,12 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


# OPT does not take position_ids as input for transformers < v4.46, needs it for transformers >= v4.46
if is_transformers_version(">=", "4.46.0"):

@register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

else:

@register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
class OPTOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
@register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"])
class OPTOnnxConfig(
TextDecoderWithPositionIdsOnnxConfig if is_transformers_version(">=", "4.46.0") else TextDecoderOnnxConfig
):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_tasks_manager_onnx("llama", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
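The OPT change above replaces two near-identical class definitions with a single class whose base is chosen by a conditional expression. The pattern can be sketched on its own as follows; every name here (`TextDecoderConfig`, `TextDecoderWithPositionIdsConfig`, `make_opt_config`, `version_tuple`) is a hypothetical stand-in, not the optimum API.

```python
class TextDecoderConfig:
    """Stand-in base class: the decoder takes input_ids and attention_mask."""

    def input_names(self):
        return ["input_ids", "attention_mask"]


class TextDecoderWithPositionIdsConfig(TextDecoderConfig):
    """Stand-in base class: the decoder additionally takes position_ids."""

    def input_names(self):
        return super().input_names() + ["position_ids"]


def version_tuple(version_string):
    # Simplified parsing for the sketch; handles plain "X.Y.Z" strings only.
    return tuple(int(part) for part in version_string.split("."))


def make_opt_config(installed_version):
    needs_position_ids = version_tuple(installed_version) >= (4, 46, 0)

    # The base class is an ordinary expression evaluated once, at class creation.
    class OPTConfig(TextDecoderWithPositionIdsConfig if needs_position_ids else TextDecoderConfig):
        pass

    return OPTConfig


new_inputs = make_opt_config("4.46.0")().input_names()
old_inputs = make_opt_config("4.45.2")().input_names()
```

Because the condition is evaluated when the class statement runs, shared attributes only need to be written once in the class body.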
@@ -477,7 +467,6 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
@register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS)
class GraniteOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
MIN_TORCH_VERSION = version.parse("2.5.0")


@register_tasks_manager_onnx("phi", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
@@ -502,17 +491,11 @@ class InternLM2OnnxConfig(LlamaOnnxConfig):

@register_tasks_manager_onnx("mistral", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"])
class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
MIN_TRANSFORMERS_VERSION = version.parse("4.35.0")

# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14
DUMMY_INPUT_GENERATOR_CLASSES = (
MistralDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
_MODEL_PATCHER = MistralModelPatcher


@register_tasks_manager_onnx("mpt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"])
@@ -556,9 +539,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
"gpt_bigcode", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]
)
class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
GPTBigCodeDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator)
DEFAULT_ONNX_OPSET = 14 # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode")
@@ -571,36 +552,29 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
decoder_sequence_name = "past_sequence_length + sequence_length"
name = "present"

for i in range(self._normalized_config.num_layers):
# No dim for `n_head` when using multi-query attention
inputs_or_outputs[f"{name}.{i}.key_value"] = {
0: "batch_size",
1: decoder_sequence_name,
}
if self._normalized_config.multi_query:
# No dim for `n_head` when using multi-query attention
inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 1: decoder_sequence_name}
else:
inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 2: decoder_sequence_name}
Comment on lines +559 to +563
Member Author

support for multi_query=True/False for gpt bigcode


def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key_value"] = t
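The `multi_query` branch in `add_past_key_values` above moves the dynamic sequence axis from index 1 to index 2. A standalone sketch of that logic follows. It assumes the fused key/value cache has shape `(batch, seq, 2 * head_dim)` with multi-query attention and `(batch, num_heads, seq, 2 * head_dim)` without it; the function name `key_value_dynamic_axes` is hypothetical.

```python
def key_value_dynamic_axes(num_layers, multi_query, direction="inputs"):
    """Build the dynamic-axes mapping for a fused key/value cache, one entry per layer."""
    if direction == "inputs":
        sequence_name, name = "past_sequence_length", "past_key_values"
    else:
        sequence_name, name = "past_sequence_length + sequence_length", "present"

    # Assumed layouts: no head dimension with multi-query attention, so the
    # sequence axis is 1; otherwise the heads sit at axis 1 and the sequence at 2.
    sequence_axis = 1 if multi_query else 2
    return {
        f"{name}.{i}.key_value": {0: "batch_size", sequence_axis: sequence_name}
        for i in range(num_layers)
    }


multi_query_axes = key_value_dynamic_axes(num_layers=2, multi_query=True)
multi_head_axes = key_value_dynamic_axes(num_layers=1, multi_query=False, direction="outputs")
```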


@register_tasks_manager_onnx("falcon", *COMMON_TEXT_GENERATION_TASKS + ["question-answering", "token-classification"])
class FalconOnnxConfig(TextDecoderOnnxConfig):
# This is due to the cache refactoring for Falcon in 4.36
MIN_TRANSFORMERS_VERSION = version.parse("4.35.99")
class FalconOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")

DUMMY_INPUT_GENERATOR_CLASSES = (
FalconDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, FalconDummyPastKeyValuesGenerator)
DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator

# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
_MODEL_PATCHER = FalconModelPatcher

def __init__(
self,
config: "PretrainedConfig",
@@ -634,10 +608,8 @@ def __init__(
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs

if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]:
# When alibi is used, position_ids are not used in Falcon.
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
if self._config.alibi:
common_inputs.pop("position_ids", None)
Comment thread
echarlaix marked this conversation as resolved.

return common_inputs
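With Falcon now inheriting `position_ids` from its parent config, the `inputs` property only has to drop that entry when ALiBi is used. A minimal sketch of the resulting behaviour is shown below; `falcon_inputs` is a hypothetical free function standing in for the property, and the parent inputs are written out by hand.

```python
def falcon_inputs(alibi):
    """Return the dynamic-axes mapping of the model inputs, depending on ALiBi usage."""
    # Hand-written stand-in for what the parent class is assumed to provide.
    common_inputs = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
    }
    if alibi:
        # The None default makes pop a no-op if the key is already absent.
        common_inputs.pop("position_ids", None)
    return common_inputs


rotary_input_names = list(falcon_inputs(alibi=False))
alibi_input_names = list(falcon_inputs(alibi=True))
```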

@@ -836,7 +808,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
)
class BartOnnxConfig(M2M100OnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
MIN_TORCH_VERSION = version.parse("2.1.2")


@register_tasks_manager_onnx(
@@ -868,7 +839,7 @@ class BigBirdPegasusOnnxConfig(BartOnnxConfig):
@property
def inputs(self) -> Dict[str, Dict[int, str]]:
inputs = super().inputs
if self._config.attention_type == "block_sparse":
if self._config.attention_type == "block_sparse" and self.task != "text-generation":
# BigBirdPegasusEncoder creates its own attention_mask internally
# https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py#L1875
inputs.pop("attention_mask", None)
@@ -888,7 +859,6 @@ class MarianOnnxConfig(BartOnnxConfig):
@register_tasks_manager_onnx("vit", *["feature-extraction", "image-classification", "masked-im"])
class ViTOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
MIN_TORCH_VERSION = version.parse("1.11")
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
@@ -1574,7 +1544,6 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
# Sets the absolute tolerance used when validating the exported ONNX model against the
# reference model.
ATOL_FOR_VALIDATION = 1e-4
MIN_TORCH_VERSION = version.parse("2.1")

# needs einsum operator support, available since opset 12
DEFAULT_ONNX_OPSET = 12
@@ -1646,7 +1615,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
"layoutlmv3", *["feature-extraction", "question-answering", "text-classification", "token-classification"]
)
class LayoutLMv3OnnxConfig(TextAndVisionOnnxConfig):
MIN_TORCH_VERSION = version.parse("1.12")
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
allow_new=True,
MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings",
@@ -2570,8 +2538,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
@register_tasks_manager_onnx("sam", *["feature-extraction"])
class SamOnnxConfig(OnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0")
# Since transformers 4.32.0, SAM uses repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429
MIN_TORCH_VERSION = version.parse("2.0.99")
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator)
DEFAULT_ONNX_OPSET = 13 # Opset 12 for repeat_interleave falls back on the opset 9 implem, that raises Unsupported: ONNX export of repeat_interleave in opset 9.