Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class ORTModel(ORTSessionMixin, OptimizedModel):

model_type = "onnx_model"
auto_model_class = AutoModel
_library_name: Optional[str] = None

def __init__(
self,
Expand Down Expand Up @@ -431,6 +432,7 @@ def _export(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
library_name=cls._library_name,
)
maybe_save_preprocessors(model_id, model_save_path, src_subfolder=subfolder)

Expand Down Expand Up @@ -628,6 +630,7 @@ class ORTModelForFeatureExtraction(ORTModel):
"""

auto_model_class = AutoModel
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -753,6 +756,7 @@ class ORTModelForMaskedLM(ORTModel):
"""

auto_model_class = AutoModelForMaskedLM
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -855,6 +859,7 @@ class ORTModelForQuestionAnswering(ORTModel):
"""

auto_model_class = AutoModelForQuestionAnswering
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -974,6 +979,7 @@ class ORTModelForSequenceClassification(ORTModel):
"""

auto_model_class = AutoModelForSequenceClassification
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1077,6 +1083,7 @@ class ORTModelForTokenClassification(ORTModel):
"""

auto_model_class = AutoModelForTokenClassification
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1173,6 +1180,7 @@ class ORTModelForMultipleChoice(ORTModel):
"""

auto_model_class = AutoModelForMultipleChoice
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1278,6 +1286,7 @@ class ORTModelForImageClassification(ORTModel):
"""

auto_model_class = AutoModelForImageClassification
_library_name: Optional[str] = "transformers"
Comment thread
echarlaix marked this conversation as resolved.
Outdated

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
Expand Down Expand Up @@ -1376,6 +1385,7 @@ class ORTModelForSemanticSegmentation(ORTModel):
"""

auto_model_class = AutoModelForSemanticSegmentation
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
Expand Down Expand Up @@ -1479,6 +1489,7 @@ class ORTModelForAudioClassification(ORTModel):
"""

auto_model_class = AutoModelForAudioClassification
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1577,6 +1588,7 @@ class ORTModelForCTC(ORTModel):
"""

auto_model_class = AutoModelForCTC
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1681,6 +1693,7 @@ class ORTModelForAudioXVector(ORTModel):
"""

auto_model_class = AutoModelForAudioXVector
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1770,6 +1783,7 @@ class ORTModelForAudioFrameClassification(ORTModel):
"""

auto_model_class = AutoModelForAudioFrameClassification
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
Expand Down Expand Up @@ -1850,6 +1864,7 @@ class ORTModelForImageToImage(ORTModel):
"""

auto_model_class = AutoModelForImageToImage
_library_name: Optional[str] = "transformers"

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,7 @@ def _export(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
library_name=cls._library_name,
)
maybe_save_preprocessors(model_id, model_save_path, src_subfolder=subfolder)

Expand Down
15 changes: 15 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,21 @@ def test_compare_to_io_binding(self, model_arch):

gc.collect()

def test_load_sentence_transformers_model_as_fill_mask(self):
    """Load a Hub checkpoint as ``ORTModelForMaskedLM`` and run fill-mask on it.

    The checkpoint is loaded with ``ORTModelForMaskedLM.from_pretrained``,
    wrapped in a ``fill-mask`` pipeline, and queried with a single masked
    sentence. The test checks that the pipeline and the model report the
    same device and that the first prediction carries a non-negative score
    and a string token.
    """
    # NOTE(review): judging by the test name this repo is a
    # sentence-transformers checkpoint -- confirm against the Hub.
    model_id = "sparse-encoder-testing/splade-bert-tiny-nq"

    onnx_model = ORTModelForMaskedLM.from_pretrained(model_id)
    tokenizer = get_preprocessor(model_id)

    # Build the prompt from the tokenizer's own mask token instead of
    # hard-coding a model-specific placeholder.
    mask_token = tokenizer.mask_token
    prompt = f"The capital of France is {mask_token}."

    # NOTE(review): device=0 asks the pipeline for device index 0 --
    # confirm the CI runner for this test provides such a device.
    fill_mask = pipeline(
        "fill-mask",
        model=onnx_model,
        tokenizer=tokenizer,
        device=0,
    )
    predictions = fill_mask(prompt)

    self.assertEqual(fill_mask.device, onnx_model.device)

    first_prediction = predictions[0]
    self.assertGreaterEqual(first_prediction["score"], 0.0)
    self.assertIsInstance(first_prediction["token_str"], str)

    gc.collect()


class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
Expand Down
Loading