Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2cc29a0
added test_decoders.py
IlyasMoutawwakil May 23, 2025
4da3df2
fix position ids for single batch and more complete decoder testing f…
IlyasMoutawwakil May 23, 2025
d208f6b
support merging seq2seq models when used as decoders and add more tests
IlyasMoutawwakil May 23, 2025
608da13
fix pipe tests
IlyasMoutawwakil May 23, 2025
f34f6e7
update phi min transformers version (broken by cache position refacto…
IlyasMoutawwakil May 23, 2025
2e1a700
remove deprecated bloom modeling
IlyasMoutawwakil May 23, 2025
00ec0c7
update opt onnx config to the one with position ids
IlyasMoutawwakil May 23, 2025
88dc4a8
remove all complex deprecated modeling
IlyasMoutawwakil May 23, 2025
5f9419e
get_supported_model_type_for_task should only return suooprted model …
IlyasMoutawwakil May 23, 2025
478fd57
update min transformers
IlyasMoutawwakil May 24, 2025
6aa3a17
use transformers like api for use_cache and add can_use_cache and is_…
IlyasMoutawwakil May 24, 2025
7da7015
testing
IlyasMoutawwakil May 24, 2025
f9f7395
fix
IlyasMoutawwakil May 25, 2025
8785be6
fix
IlyasMoutawwakil May 25, 2025
5f81515
remove unnecessary
IlyasMoutawwakil May 25, 2025
6e3bff1
simply qwen3
IlyasMoutawwakil May 25, 2025
aacf172
docs
IlyasMoutawwakil May 25, 2025
2b0137f
qwen-moe
IlyasMoutawwakil May 25, 2025
7041a89
model type shenanigans
IlyasMoutawwakil May 25, 2025
de244d5
fix
IlyasMoutawwakil May 25, 2025
088b265
use test models from optimum-internal-hf with proper metadata
IlyasMoutawwakil May 26, 2025
af7c6bb
Update optimum/onnxruntime/modeling_decoder.py
IlyasMoutawwakil May 27, 2025
0149349
keep supported model types
IlyasMoutawwakil May 27, 2025
2d9d7ea
Merge branch 'distribute-tests' of https://github.com/huggingface/opt…
IlyasMoutawwakil May 27, 2025
fbcf3a1
Merge branch 'main' into distribute-tests
IlyasMoutawwakil May 27, 2025
2cf507c
optimum model
IlyasMoutawwakil May 27, 2025
8821d97
fix failing test by forcing export
IlyasMoutawwakil May 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/test_onnxruntime.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ jobs:
matrix:
python-version: [3.9]
runs-on: [ubuntu-22.04]
test_file: [
test_file:
[
test_timm.py,
test_modeling.py, # todo: split into test_encoder, test_decoder and test_encoder_decoder
test_decoder.py,
test_modeling.py,
test_diffusion.py,
test_optimization.py,
test_quantization.py,
Expand Down
18 changes: 9 additions & 9 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,19 +938,19 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
if len(onnx_files_subpaths) >= 3 and self.use_past is True:
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
# Attempt to merge only if the decoder was exported without/with past
onnx_decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
onnx_decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
if onnx_decoder_path.is_file() and onnx_decoder_with_past_path.is_file() and self.use_past is True:
try:
from ...onnx import merge_decoders

# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
from ...onnx import merge_decoders

merge_decoders(
decoder=decoder_path,
decoder_with_past=decoder_with_past_path,
decoder=onnx_decoder_path,
decoder_with_past=onnx_decoder_with_past_path,
save_path=decoder_merged_path,
strict=False,
)
Expand Down
33 changes: 5 additions & 28 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


# OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46
if is_transformers_version(">=", "4.45.99"):
if is_transformers_version(">=", "4.46.0"):

class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
Expand All @@ -352,7 +352,6 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):

class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
Expand Down Expand Up @@ -385,7 +384,7 @@ class GraniteOnnxConfig(LlamaOnnxConfig):
class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
MIN_TRANSFORMERS_VERSION = version.parse("4.42.0")


class Phi3OnnxConfig(PhiOnnxConfig):
Expand Down Expand Up @@ -430,33 +429,11 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
BloomDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES

DEFAULT_ONNX_OPSET = 14 # Bloom uses F.scaled_dot_product_attention
MIN_TRANSFORMERS_VERSION = version.parse("4.44.0")
DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if is_transformers_version(">=", "4.44"):
super().add_past_key_values(inputs_or_outputs, direction)
else:
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {
0: "batch_size x num_heads",
2: decoder_sequence_name,
}
inputs_or_outputs[f"{name}.{i}.value"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
}


if is_transformers_version(">=", "4.45.99"):
if is_transformers_version(">=", "4.46.0"):
MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt")


Expand Down
15 changes: 13 additions & 2 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from transformers import AutoConfig, PretrainedConfig, is_tf_available, is_torch_available
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, http_user_agent, logging

from ..utils.import_utils import is_diffusers_available, is_onnx_available
from ..utils.import_utils import is_diffusers_available, is_onnx_available, is_transformers_version


if TYPE_CHECKING:
Expand Down Expand Up @@ -1475,12 +1475,23 @@ def get_supported_model_type_for_task(task: str, exporter: str) -> List[str]:
"""
Returns the list of supported architectures by the exporter for a given task. Transformers-specific.
"""
return [

supported_model_types = [
model_type.replace("-", "_")
for model_type in TasksManager._SUPPORTED_MODEL_TYPE
if task in TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]
and is_transformers_version(
">=",
str(
TasksManager.get_exporter_config_constructor(
exporter, task=task, model_type=model_type
).func.MIN_TRANSFORMERS_VERSION
),
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking it might be easier to raise an error instead before export when the transformers version is not compatible and to keep all supported architectures so that users know that the architecture export is supported but that transformers needs to be upgraded, what do you think @IlyasMoutawwakil ?

f"{config.MIN_TRANSFORMERS_VERSION}, got: {transformers.__version__}"

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fact I only added this here to be able to use it in test_find_untested_architectures
I can move the version checks there and keep this method as is.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done ! I simply remove them the unsupported models (because version) using CONFIG_MAPPING_NAMES

        supported_transformers_models = set(CONFIG_MAPPING_NAMES.keys())
        supported_export_models = set(TasksManager.get_supported_model_type_for_task(task=self.TASK, exporter="onnx"))
        supported_export_models = supported_export_models & supported_transformers_models
        untested_models = supported_export_models - tested_models

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the raising of version error during export, I thin that's already the case

]

return supported_model_types

@staticmethod
def synonyms_for_task(task: str) -> Set[str]:
synonyms = [k for k, v in TasksManager._SYNONYM_TASK_MAP.items() if v == task]
Expand Down
Loading