-
Notifications
You must be signed in to change notification settings - Fork 642
Fix batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.. #2326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Fix batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.. #2326
Changes from all commits
Commits
Show all changes
44 commits
Select commit
Hold shift + click to select a range
63a6efe
test left-padded batched inference
IlyasMoutawwakil 39496d8
demonstrate batched tex generation failure
IlyasMoutawwakil 2ccc150
fix remote code
IlyasMoutawwakil ecf65d5
fix
IlyasMoutawwakil 9f3eedc
fix position_ids generation inside ORTModelForCausalLM class
IlyasMoutawwakil b7bec5e
it works until transformers 4.52 -_-
IlyasMoutawwakil 0df42e5
now run with latest transformers
IlyasMoutawwakil 999a145
bolean 4D mask is actually not supported by torch onnx exporter
IlyasMoutawwakil 638856e
only test generation with batched inputs, for logits are a bit off be…
IlyasMoutawwakil 3d40502
boolean mask safe softmax batched inference
IlyasMoutawwakil 023d2ac
style
IlyasMoutawwakil accf852
use old typing
IlyasMoutawwakil 0965ea9
don't do unnecessary patching
IlyasMoutawwakil d1f9bbd
try to avoid spamming the hub for an image
IlyasMoutawwakil 01c4084
update min transformers version
IlyasMoutawwakil aeeecb2
better and direct torch patching
IlyasMoutawwakil fc62f42
more batched generation special cases
IlyasMoutawwakil ba994fb
style
IlyasMoutawwakil de6a798
initialize the il image instead of downloading it
IlyasMoutawwakil cf164b3
use random pil image
IlyasMoutawwakil 5934bf9
test different versions of transformers in fast tests
IlyasMoutawwakil 4b76f5e
fix
IlyasMoutawwakil e171196
revert diffusers changes for now
IlyasMoutawwakil 5ab88b6
mask padding kv cache as well
IlyasMoutawwakil 4d35600
fix masking for old bloom
IlyasMoutawwakil b2a5f41
use constant image to image loading errors
IlyasMoutawwakil 3f58892
style
IlyasMoutawwakil b9d2e03
test diffusers in series to avoid runner dying
IlyasMoutawwakil bdcc425
fix
IlyasMoutawwakil a3dc4e8
cleanup and some comments
IlyasMoutawwakil a1ff2f2
fix and test falcon alibi
IlyasMoutawwakil 603f62c
style
IlyasMoutawwakil cf5b562
fix, support and test multi_query=False as well
IlyasMoutawwakil 3a29549
only apply masked testing for transformers version previous to 4.39
IlyasMoutawwakil af5fa34
Update optimum/onnxruntime/modeling_decoder.py
IlyasMoutawwakil 59c0c14
use text decoder position ids onnx config but test its sync with list
IlyasMoutawwakil b5d92e5
Merge branch 'fix-ort-batched-generation' of https://github.com/huggi…
IlyasMoutawwakil 9db07bf
fix opt
IlyasMoutawwakil 98123d4
style
IlyasMoutawwakil 411df8f
fix sdpa without overriting torch onnx exporter
IlyasMoutawwakil 133f340
use inplace op ;-;
IlyasMoutawwakil 9044948
Merge branch 'main' into fix-ort-batched-generation
IlyasMoutawwakil c98ab28
fix st test
IlyasMoutawwakil e787b92
patch directly in onnx because patch needs to happen after softmax
IlyasMoutawwakil File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -92,9 +92,7 @@ | |
| from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME | ||
| from .model_patcher import ( | ||
| CLIPModelPatcher, | ||
| FalconModelPatcher, | ||
| MgpstrModelPatcher, | ||
| MistralModelPatcher, | ||
| MusicgenModelPatcher, | ||
| Qwen3MoeModelPatcher, | ||
| SAMModelPatcher, | ||
|
|
@@ -409,20 +407,12 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): | |
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig | ||
|
|
||
|
|
||
| # OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46 | ||
| if is_transformers_version(">=", "4.46.0"): | ||
|
|
||
| @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"]) | ||
| class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): | ||
| DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig | ||
|
|
||
| else: | ||
|
|
||
| @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"]) | ||
| class OPTOnnxConfig(TextDecoderOnnxConfig): | ||
| DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig | ||
| @register_tasks_manager_onnx("opt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "question-answering"]) | ||
| class OPTOnnxConfig( | ||
| TextDecoderWithPositionIdsOnnxConfig if is_transformers_version(">=", "4.46.0") else TextDecoderOnnxConfig | ||
| ): | ||
| DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig | ||
|
|
||
|
|
||
| @register_tasks_manager_onnx("llama", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) | ||
|
|
@@ -477,7 +467,6 @@ class GemmaOnnxConfig(LlamaOnnxConfig): | |
| @register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS) | ||
| class GraniteOnnxConfig(LlamaOnnxConfig): | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.45.0") | ||
| MIN_TORCH_VERSION = version.parse("2.5.0") | ||
|
|
||
|
|
||
| @register_tasks_manager_onnx("phi", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) | ||
|
|
@@ -502,17 +491,11 @@ class InternLM2OnnxConfig(LlamaOnnxConfig): | |
|
|
||
| @register_tasks_manager_onnx("mistral", *COMMON_TEXT_GENERATION_TASKS + ["text-classification"]) | ||
| class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): | ||
| # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.35.0") | ||
|
|
||
| # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 | ||
| DEFAULT_ONNX_OPSET = 14 | ||
| DUMMY_INPUT_GENERATOR_CLASSES = ( | ||
| MistralDummyPastKeyValuesGenerator, | ||
| ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES | ||
| DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator | ||
| DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) | ||
| _MODEL_PATCHER = MistralModelPatcher | ||
|
|
||
|
|
||
| @register_tasks_manager_onnx("mpt", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"]) | ||
|
|
@@ -556,9 +539,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire | |
| "gpt_bigcode", *COMMON_TEXT_GENERATION_TASKS + ["text-classification", "token-classification"] | ||
| ) | ||
| class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): | ||
| DUMMY_INPUT_GENERATOR_CLASSES = ( | ||
| GPTBigCodeDummyPastKeyValuesGenerator, | ||
| ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES | ||
| DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator) | ||
| DEFAULT_ONNX_OPSET = 14 # GPT BigCode now uses F.scaled_dot_product_attention by default for torch>=2.1.1. | ||
| DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator | ||
| NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode") | ||
|
|
@@ -571,36 +552,29 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire | |
| decoder_sequence_name = "past_sequence_length" | ||
| name = "past_key_values" | ||
| else: | ||
| decoder_sequence_name = "past_sequence_length + 1" | ||
| decoder_sequence_name = "past_sequence_length + sequence_length" | ||
| name = "present" | ||
|
|
||
| for i in range(self._normalized_config.num_layers): | ||
| # No dim for `n_head` when using multi-query attention | ||
| inputs_or_outputs[f"{name}.{i}.key_value"] = { | ||
| 0: "batch_size", | ||
| 1: decoder_sequence_name, | ||
| } | ||
| if self._normalized_config.multi_query: | ||
| # No dim for `n_head` when using multi-query attention | ||
| inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 1: decoder_sequence_name} | ||
| else: | ||
| inputs_or_outputs[f"{name}.{i}.key_value"] = {0: "batch_size", 2: decoder_sequence_name} | ||
|
Comment on lines
+559
to
+563
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. support for multi_query=True/False for gpt bigcode |
||
|
|
||
| def flatten_past_key_values(self, flattened_output, name, idx, t): | ||
| flattened_output[f"{name}.{idx}.key_value"] = t | ||
|
|
||
|
|
||
| @register_tasks_manager_onnx("falcon", *COMMON_TEXT_GENERATION_TASKS + ["question-answering", "token-classification"]) | ||
| class FalconOnnxConfig(TextDecoderOnnxConfig): | ||
| # This is due to the cache refactoring for Falcon in 4.36 | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.35.99") | ||
| class FalconOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.36.0") | ||
|
|
||
| DUMMY_INPUT_GENERATOR_CLASSES = ( | ||
| FalconDummyPastKeyValuesGenerator, | ||
| ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES | ||
| DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, FalconDummyPastKeyValuesGenerator) | ||
| DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig | ||
| DUMMY_PKV_GENERATOR_CLASS = FalconDummyPastKeyValuesGenerator | ||
|
|
||
| # we need to set output_attentions=True in the model input to avoid calling | ||
| # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export | ||
| _MODEL_PATCHER = FalconModelPatcher | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: "PretrainedConfig", | ||
|
|
@@ -634,10 +608,8 @@ def __init__( | |
| def inputs(self) -> Dict[str, Dict[int, str]]: | ||
| common_inputs = super().inputs | ||
|
|
||
| if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]: | ||
| # When alibi is used, position_ids are not used in Falcon. | ||
| # Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116 | ||
| common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} | ||
| if self._config.alibi: | ||
| common_inputs.pop("position_ids", None) | ||
|
echarlaix marked this conversation as resolved.
|
||
|
|
||
| return common_inputs | ||
|
|
||
|
|
@@ -836,7 +808,6 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): | |
| ) | ||
| class BartOnnxConfig(M2M100OnnxConfig): | ||
| DEFAULT_ONNX_OPSET = 14 # Bart now uses F.scaled_dot_product_attention by default for torch>=2.1.1. | ||
| MIN_TORCH_VERSION = version.parse("2.1.2") | ||
|
|
||
|
|
||
| @register_tasks_manager_onnx( | ||
|
|
@@ -868,7 +839,7 @@ class BigBirdPegasusOnnxConfig(BartOnnxConfig): | |
| @property | ||
| def inputs(self) -> Dict[str, Dict[int, str]]: | ||
| inputs = super().inputs | ||
| if self._config.attention_type == "block_sparse": | ||
| if self._config.attention_type == "block_sparse" and self.task != "text-generation": | ||
| # BigBirdPegasusEncoder creates its own attention_mask internally | ||
| # https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py#L1875 | ||
| inputs.pop("attention_mask", None) | ||
|
|
@@ -888,7 +859,6 @@ class MarianOnnxConfig(BartOnnxConfig): | |
| @register_tasks_manager_onnx("vit", *["feature-extraction", "image-classification", "masked-im"]) | ||
| class ViTOnnxConfig(VisionOnnxConfig): | ||
| NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig | ||
| MIN_TORCH_VERSION = version.parse("1.11") | ||
| DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. | ||
|
|
||
| @property | ||
|
|
@@ -1574,7 +1544,6 @@ class OwlViTOnnxConfig(CLIPOnnxConfig): | |
| # Sets the absolute tolerance to when validating the exported ONNX model against the | ||
| # reference model. | ||
| ATOL_FOR_VALIDATION = 1e-4 | ||
| MIN_TORCH_VERSION = version.parse("2.1") | ||
|
|
||
| # needs einsum operator support, available since opset 12 | ||
| DEFAULT_ONNX_OPSET = 12 | ||
|
|
@@ -1646,7 +1615,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]: | |
| "layoutlmv3", *["feature-extraction", "question-answering", "text-classification", "token-classification"] | ||
| ) | ||
| class LayoutLMv3OnnxConfig(TextAndVisionOnnxConfig): | ||
| MIN_TORCH_VERSION = version.parse("1.12") | ||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( | ||
| allow_new=True, | ||
| MAX_2D_POSITION_EMBEDDINGS="max_2d_position_embeddings", | ||
|
|
@@ -2570,8 +2538,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]: | |
| @register_tasks_manager_onnx("sam", *["feature-extraction"]) | ||
| class SamOnnxConfig(OnnxConfig): | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.29.0.dev0") | ||
| # Since ransformers 4.32.0, SAM uses repeat_interleave op that is broken in PyTorch 2.0.1: https://github.com/pytorch/pytorch/issues/100429 | ||
| MIN_TORCH_VERSION = version.parse("2.0.99") | ||
| NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig | ||
| DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator, DummyVisionEmbeddingsGenerator) | ||
| DEFAULT_ONNX_OPSET = 13 # Opset 12 for repeat_interleave falls back on the opset 9 implem, that raises Unsupported: ONNX export of repeat_interleave in opset 9. | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was wrong