Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Marian
- MarkupLM
- MBart
- MGP-STR
- Mistral
- MobileBert
- MobileVit
Expand Down
16 changes: 16 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MgpstrModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
Expand Down Expand Up @@ -895,6 +896,21 @@ def torch_to_onnx_input_map(self) -> Dict[str, str]:
return {"x": "pixel_values"}


class MgpstrOnnxConfig(ViTOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
"char_logits": {0: "batch_size"},
"bpe_logits": {0: "batch_size"},
"wp_logits": {0: "batch_size"},
}

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MgpstrModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersTransformerOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DEFAULT_ONNX_OPSET = 14 # Some bottleneck transformers models require a specific ONNX opset to be successfully exported. We put a rather high opset here for the export to work for all architectures.
Expand Down
26 changes: 26 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,32 @@ def patched_forward(*args, **kwargs):
self.patched_forward = patched_forward


class MgpstrModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

# logits is a tuple, so we unpack it and return them as separate outputs
char_logits, bpe_logits, wp_logits = self.orig_forward(*args, **kwargs).logits
Comment thread
xenova marked this conversation as resolved.

return {
"char_logits": char_logits,
"bpe_logits": bpe_logits,
"wp_logits": wp_logits,
}

self.patched_forward = patched_forward


class SAMModelPatcher(ModelPatcher):
def __init__(
self,
Expand Down
7 changes: 6 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class TasksManager:
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
"image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
"multiple-choice": "AutoModelForMultipleChoice",
Expand Down Expand Up @@ -797,6 +797,11 @@ class TasksManager:
"question-answering",
onnx="MBartOnnxConfig",
),
"mgp-str": supported_tasks_mapping(
"feature-extraction",
"image-to-text",
onnx="MgpstrOnnxConfig",
),
"mistral": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
Expand Down Expand Up @@ -245,6 +246,7 @@
"marian": "Helsinki-NLP/opus-mt-en-de",
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "sshleifer/tiny-mbart",
"mgp-str": "alibaba-damo/mgp-str-base",
"mobilebert": "google/mobilebert-uncased",
# "mobilenet_v1": "google/mobilenet_v1_0.75_192",
# "mobilenet_v2": "google/mobilenet_v2_0.35_96",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "echarlaix/tiny-random-marian",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mgp-str": "hf-internal-testing/tiny-random-MgpstrForSceneTextRecognition",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet_v1": "google/mobilenet_v1_0.75_192",
Expand Down