Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Data2VecVision
- Deberta
- Deberta-v2
- Decision Transformer
- Deit
- Detr
- DistilBert
Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ class OnnxConfig(ExportConfig, ABC):
"end_logits": {0: "batch_size", 1: "sequence_length"},
}
),
"reinforcement-learning": OrderedDict(
{
"state_preds": {0: "batch_size", 1: "sequence_length", 2: "states"},
"action_preds": {0: "batch_size", 1: "sequence_length", 2: "actions"},
"return_preds": {0: "batch_size", 1: "sequence_length", 2: "returns"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"},
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
}
),
"semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
Expand Down
19 changes: 19 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

from packaging import version
from sipbuild.generator.parser.tokens import states
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
from transformers.utils import is_tf_available

from ...onnx import merge_decoders
Expand All @@ -27,6 +28,7 @@
DummyAudioInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummyDecisionTransformerInputGenerator,
DummyEncodecInputGenerator,
DummyInputGenerator,
DummyIntGenerator,
Expand Down Expand Up @@ -256,6 +258,23 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
pass


class DecisionTransformerOnnxConfig(GPT2OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyDecisionTransformerInputGenerator,
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:

return {
'states': {0: 'batch_size', 1: 'sequence_length', 2: 'states'},
'actions': {0: 'batch_size', 1: 'sequence_length', 2: 'actions'},
'returns_to_go': {0: 'batch_size', 1: 'sequence_length', 2: 'returns'},
'timesteps': {0: 'batch_size', 1: 'sequence_length'},
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
}


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
Expand Down
10 changes: 10 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class TasksManager:
"multiple-choice": "AutoModelForMultipleChoice",
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"reinforcement-learning": "AutoModel",
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"text-to-audio": ("AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform"),
"text-generation": "AutoModelForCausalLM",
Expand Down Expand Up @@ -562,6 +563,12 @@ class TasksManager:
onnx="DebertaV2OnnxConfig",
tflite="DebertaV2TFLiteConfig",
),
"decision-transformer": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"reinforcement-learning",
onnx="DecisionTransformerOnnxConfig",
),
"deit": supported_tasks_mapping(
"feature-extraction",
"image-classification",
Expand Down Expand Up @@ -2061,6 +2068,9 @@ def get_model_from_task(
if original_task == "automatic-speech-recognition" or task == "automatic-speech-recognition":
if original_task == "auto" and config.architectures is not None:
model_class_name = config.architectures[0]
elif original_task == "reinforcement-learning" or task == "reinforcement-learning":
if config.architectures is not None:
model_class_name = config.architectures[0]

if library_name == "diffusers":
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyCodegenDecoderTextInputGenerator,
DummyDecisionTransformerInputGenerator,
DummyDecoderTextInputGenerator,
DummyEncodecInputGenerator,
DummyInputGenerator,
Expand Down
36 changes: 36 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,42 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
)


class DummyDecisionTransformerInputGenerator(DummyTextInputGenerator):
"""
Generates dummy decision transformer inputs.
"""

SUPPORTED_INPUT_NAMES = (
'actions',
'timesteps',
'attention_mask',
'returns_to_go',
'states',
)

Comment thread
IlyasMoutawwakil marked this conversation as resolved.

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.act_dim = self.normalized_config.config.act_dim
self.state_dim = self.normalized_config.config.state_dim
self.max_ep_len = self.normalized_config.config.max_ep_len

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "states":
shape = [self.batch_size, self.sequence_length, self.state_dim]
elif input_name == "actions":
shape = [self.batch_size, self.sequence_length, self.act_dim]
elif input_name == 'returns_to_go':
shape = [self.batch_size, self.sequence_length, 1]
elif input_name == "attention_mask":
shape = [self.batch_size, self.sequence_length]
elif input_name == 'timesteps':
shape = [self.batch_size, self.sequence_length]
return self.random_int_tensor(shape=shape, max_value=max_ep_len, framework=framework, dtype=int_dtype)
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated

return self.random_float_tensor(shape, min_value=-2., max_value=2., framework=framework, dtype=float_dtype)


class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
SUPPORTED_INPUT_NAMES = (
"decoder_input_ids",
Expand Down
9 changes: 9 additions & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class NormalizedTextConfig(NormalizedConfig):
EOS_TOKEN_ID = "eos_token_id"


class NormalizedDecisionTransformerConfig(NormalizedConfig):
# REFERENCE: https://huggingface.co/docs/transformers/model_doc/decision_transformer
STATE_DIM = "state_dim"
ACT_DIM = "act_dim"
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
MAX_EP_LEN = "max_ep_len"
HIDDEN_SIZE = "hidden_size"


class NormalizedTextConfigWithGQA(NormalizedTextConfig):
NUM_KEY_VALUE_HEADS = "num_key_value_heads"

Expand Down Expand Up @@ -236,6 +244,7 @@ class NormalizedConfigManager:
"cvt": NormalizedVisionConfig,
"deberta": NormalizedTextConfig,
"deberta-v2": NormalizedTextConfig,
"decision-transformer": NormalizedDecisionTransformerConfig,
"deit": NormalizedVisionConfig,
"distilbert": NormalizedTextConfig.with_args(num_attention_heads="n_heads", hidden_size="dim"),
"donut-swin": NormalizedVisionConfig,
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"data2vec-audio": "hf-internal-testing/tiny-random-Data2VecAudioModel",
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"decision-transformer": "edbeeching/decision-transformer-gym-hopper-medium",
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
Expand Down