Skip to content
Merged
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
pass


class GraniteOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
MIN_TORCH_VERSION = version.parse("2.5.0")


class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"phi",
"phi3",
"qwen2",
"granite",
}


Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,13 @@ class TasksManager:
"text-classification",
onnx="LlamaOnnxConfig",
),
"granite": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
onnx="GraniteOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2"}:
elif self.model_type in {"mistral", "llama", "qwen2", "granite"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ORTConfigManager:
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
"granite": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"qwen2": NormalizedTextConfig,
"granite": NormalizedTextConfigWithGQA,
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,6 +2324,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neo",
"gpt_neox",
"gptj",
"granite",
"llama",
"mistral",
"mpt",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
Expand Down