ONNX Export Support for jinaai/jina-embeddings-v5-text-nano-retrieval #3666

@contrebande-labs

Description

I'm using the current git `main` of both transformers and sentence-transformers (as of about 30 minutes ago).

I'm unable to load the model with sentence-transformers and export it to ONNX with the PyTorch exporter.

When I load the model I get these warnings:

```
Unrecognized keys in `rope_parameters` for 'rope_type'='default': {'factor'}
--- Logging error ---
Traceback (most recent call last):
  File "/usr/lib/python3.12/logging/__init__.py", line 1160, in emit
    msg = self.format(record)
          ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/logging/__init__.py", line 999, in format
    return fmt.format(record)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/logging/__init__.py", line 703, in format
    record.message = record.getMessage()
                     ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/logging/__init__.py", line 392, in getMessage
    msg = msg % self.args
          ~~~~^~~~~~~~~~~
TypeError: not all arguments converted during string formatting
Call stack:
  File "/home/user/Documents/git/osmatique-machine/ir-sts/onnx-export.py", line 357, in <module>
    st_pytorch_model = SentenceTransformer(st_torch_input_model_path, backend="torch", device=torch_device, trust_remote_code=True, local_files_only=False, model_kwargs={"dtype": torch.float32}).eval()
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/sentence_transformers/SentenceTransformer.py", line 327, in __init__
    modules, self.module_kwargs = self._load_sbert_model(
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/sentence_transformers/SentenceTransformer.py", line 2305, in _load_sbert_model
    module = module_class.load(
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/sentence_transformers/models/Transformer.py", line 436, in load
    return cls(model_name_or_path=model_name_or_path, **init_kwargs)
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/sentence_transformers/models/Transformer.py", line 121, in __init__
    self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/sentence_transformers/models/Transformer.py", line 270, in _load_model
    self.auto_model = AutoModel.from_pretrained(
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 367, in from_pretrained
    return model_class.from_pretrained(
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4048, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/user/.cache/huggingface/modules/transformers_modules/jina_hyphen_embeddings_hyphen_v5_hyphen_text_hyphen_nano_hyphen_retrieval/modeling_eurobert.py", line 523, in __init__
    self.mask_converter = AttentionMaskConverter(is_causal=False)
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/transformers/modeling_attn_mask_utils.py", line 74, in __init__
    logger.warning_once(DEPRECATION_MESSAGE, FutureWarning)
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/transformers/utils/logging.py", line 327, in warning_once
    self.warning(*args, **kwargs)
Message: 'The attention mask API under `transformers.modeling_attn_mask_utils` (`AttentionMaskConverter`) is deprecated and will be removed in Transformers v5.10. Please use the new API in `transformers.masking_utils`.'
Arguments: (<class 'FutureWarning'>,)
Loading weights: 100%|██████████████████████████████████████████████████████████████| 110/110 [00:03<00:00, 31.33it/s, Materializing param=norm.weight]
Unrecognized keys in `rope_parameters` for 'rope_type'='default': {'factor'}
Loaded pytorch model.
```
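For what it's worth, the `--- Logging error ---` part looks cosmetic rather than fatal: `warning_once(DEPRECATION_MESSAGE, FutureWarning)` passes `FutureWarning` as a positional printf-style argument, and since the message contains no `%` placeholders, the stdlib's `msg % self.args` blows up. A minimal stdlib-only sketch of that failure mode (the message string below is abbreviated by me, not the real one):

```python
import logging

# Build a LogRecord the way logging does when warning("msg", FutureWarning)
# is called: FutureWarning lands in record.args as a printf-style argument.
record = logging.LogRecord(
    name="demo",
    level=logging.WARNING,
    pathname=__file__,
    lineno=1,
    msg="The attention mask API ... is deprecated.",  # no % placeholders
    args=(FutureWarning,),  # stray positional argument
    exc_info=None,
)

try:
    record.getMessage()  # performs msg % args under the hood
    failed = False
except TypeError:  # "not all arguments converted during string formatting"
    failed = True

print(failed)  # True: same TypeError as in the logging traceback above
```

So the deprecation warning itself is harmless; the `FutureWarning` argument just shouldn't be forwarded positionally into `Logger.warning`.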

And when I try to export to ONNX, I get this error:

```
W0219 11:47:09.061000 25308 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.
W0219 11:47:09.061000 25308 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'boxes' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0, sampling_ratio: 'int' = -1, aligned: 'bool' = False). Treating as an Input.
W0219 11:47:09.061000 25308 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'input' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0). Treating as an Input.
W0219 11:47:09.062000 25308 torch/onnx/_internal/exporter/_schemas.py:455] Missing annotation for parameter 'boxes' from (input, boxes, output_size: 'Sequence[int]', spatial_scale: 'float' = 1.0). Treating as an Input.
[torch.onnx] Obtain model graph for `SentenceTransformer([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SentenceTransformer([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ❌
Traceback (most recent call last):
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 610, in _handle_call_function_node_with_lowering
    outputs = onnx_function(*onnx_args, **onnx_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/onnxscript/_internal/values.py", line 476, in __call__
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1839, in aten_scaled_dot_product_attention
    key, value = _attention_repeat_kv_for_group_query(query, key, value)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/onnxscript/function_libs/torch_lib/ops/nn.py", line 1756, in _attention_repeat_kv_for_group_query
    query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0
AssertionError: SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 815, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 612, in _handle_call_function_node_with_lowering
    raise _errors.GraphConstructionError(
torch.onnx._internal.exporter._errors.GraphConstructionError: Error when calling function 'TracedOnnxFunction(<function aten_scaled_dot_product_attention at 0x72c2e1a68860>)' with args '[SymbolicTensor(name='add_173', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_add_173', index=0), SymbolicTensor(name='add_209', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_add_209', index=0), SymbolicTensor(name='transpose_3', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_transpose_3', index=0), SymbolicTensor(name='masked_fill', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 1, SymbolicDim(s11), SymbolicDim(s11)]), producer='node_masked_fill', index=0)]' and kwargs '{'scale': 0.125, 'enable_gqa': True}'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 1489, in export
    onnx_program = _exported_program_to_onnx_program(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 1121, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 841, in _translate_fx_graph
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%add_173, %add_209, %transpose_3, %masked_fill), kwargs = {scale: 0.125, enable_gqa: True}). See the stack trace for more information.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/Documents/git/osmatique-machine/ir-sts/onnx-export.py", line 391, in <module>
    st_onnx_program = torch.onnx.export(
                      ^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/__init__.py", line 296, in export
    return _compat.export_compat(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_compat.py", line 154, in export_compat
    onnx_program = _core.export(
                   ^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_flags.py", line 27, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/venv/sentence-transformers-onnx/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 1535, in export
    raise _errors.ConversionError(
torch.onnx._internal.exporter._errors.ConversionError: Failed to convert the exported program to an ONNX model. This is step 3/3 of exporting the model to ONNX. Next steps:
- If there is a missing ONNX function, implement it and register it to the registry.
- If there is an internal error during ONNX conversion, debug the error and submit a PR to PyTorch.
- Create an error report with `torch.onnx.export(..., report=True)`, and save the ExportedProgram as a pt2 file. Create an issue in the PyTorch GitHub repository against the *onnx* component. Attach the error report and the pt2 model.

## Exception summary

<class 'AssertionError'>: SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0
⬆️
<class 'torch.onnx._internal.exporter._errors.GraphConstructionError'>: Error when calling function 'TracedOnnxFunction(<function aten_scaled_dot_product_attention at 0x72c2e1a68860>)' with args '[SymbolicTensor(name='add_173', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_add_173', index=0), SymbolicTensor(name='add_209', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_add_209', index=0), SymbolicTensor(name='transpose_3', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 12, SymbolicDim(s11), 64]), producer='node_transpose_3', index=0), SymbolicTensor(name='masked_fill', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s28), 1, SymbolicDim(s11), SymbolicDim(s11)]), producer='node_masked_fill', index=0)]' and kwargs '{'scale': 0.125, 'enable_gqa': True}'
⬆️
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %scaled_dot_product_attention : [num_users=1] = call_function[target=torch.ops.aten.scaled_dot_product_attention.default](args = (%add_173, %add_209, %transpose_3, %masked_fill), kwargs = {scale: 0.125, enable_gqa: True}). See the stack trace for more information.

(Refer to the full stack trace above for more information.)
```
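If I read the assertion correctly, onnxscript only accepts `enable_gqa=True` when the query has strictly more heads than key/value, yet the traced query, key, and value tensors above all have 12 heads (shape `[batch, 12, seq, 64]`), so this is plain multi-head attention and `enable_gqa=True` shouldn't be passed. Here is my own tiny sketch of that invariant (not the actual onnxscript code, just the condition from `_attention_repeat_kv_for_group_query` rewritten as a helper):

```python
def gqa_heads_ok(q_num_heads: int, kv_num_heads: int) -> bool:
    # Grouped-query attention needs strictly more query heads than kv heads,
    # and the query heads must divide evenly into the kv groups.
    return q_num_heads > kv_num_heads and q_num_heads % kv_num_heads == 0

print(gqa_heads_ok(12, 2))   # True: a genuine GQA layout (6 queries per kv head)
print(gqa_heads_ok(12, 12))  # False: standard MHA, which trips the exporter's assert
```

So a plausible fix on the model side would be for the EuroBERT attention code (or the SDPA call it reaches) to only set `enable_gqa=True` when the head counts actually differ, but I haven't verified where that flag gets set.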

I am aware that @tomaarsen is currently working on bringing EuroBERT support to transformers with this PR, but I just wanted to make sure the ONNX export works as well. If I can help, let me know.

-- Vincent
