Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,7 @@ class Phi3OnnxConfig(PhiOnnxConfig):
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")

def __init__(self, *args, **kwargs):
if is_transformers_version("==", "4.46.0"):
logger.error(
"Found transformers v4.46.0 while trying to exporting a Phi3 model, this specific version of transformers is not supported. "
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for 4.46 fixed in : huggingface/transformers#34376
for 4.49 fixed in : huggingface/transformers#36332
currently failing with v4.45 for num_beams > 1 : https://github.com/huggingface/optimum/actions/runs/14990250710

upgrading min transformers needed to v4.50 cc @IlyasMoutawwakil

"Please upgrade to v4.46.1 or higher, or downgrade your transformers version"
)
super().__init__(*args, **kwargs)
MIN_TRANSFORMERS_VERSION = version.parse("4.50.0")


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand Down
5 changes: 4 additions & 1 deletion tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@
self.assertEqual(model.unet.device, torch.device("cuda:0"))
self.assertEqual(model.text_encoder.device, torch.device("cuda:0"))
self.assertEqual(model.vae_decoder.device, torch.device("cuda:0"))
self.assertEqual(model.vae_encoder.device, torch.device("cuda:0"))

Check warning on line 943 in tests/onnxruntime/test_modeling.py

View workflow job for this annotation

GitHub Actions / trufflehog

Found unverified HuggingFace result 🐷🔑
self.assertEqual(model.unet.session.get_providers()[0], "ROCMExecutionProvider")
self.assertEqual(model.text_encoder.session.get_providers()[0], "ROCMExecutionProvider")
self.assertEqual(model.vae_decoder.session.get_providers()[0], "ROCMExecutionProvider")
Expand Down Expand Up @@ -2549,11 +2549,14 @@

# TODO: fix "mpt" for which inference fails for transformers < v4.41
if is_transformers_version(">=", "4.41"):
SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"])
SUPPORTED_ARCHITECTURES.append("mpt")

if is_transformers_version(">=", "4.45"):
SUPPORTED_ARCHITECTURES.append("granite")

if is_transformers_version(">=", "4.50"):
SUPPORTED_ARCHITECTURES.append("phi3")

FULL_GRID = {
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [False, True],
Expand Down
Loading