Skip to content

Fix and test batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,..#28

Merged
echarlaix merged 8 commits into
mainfrom
batched-inference
Aug 1, 2025
Merged

Fix and test batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,..#28
echarlaix merged 8 commits into
mainfrom
batched-inference

Conversation

@IlyasMoutawwakil
Copy link
Copy Markdown
Member

@IlyasMoutawwakil IlyasMoutawwakil commented Aug 1, 2025

@IlyasMoutawwakil IlyasMoutawwakil changed the title Fix and test batched inference/generation Fix and test batched inference/generation, position_ids creation, falcon alibi, gpt_bigcode multi-query,.. Aug 1, 2025
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

The PR fixes and tests batched inference/generation, position_ids creation, falcon alibi, and gpt_bigcode multi-query attention handling across the ONNX runtime and exporters.

  • Updates decoder modeling to properly handle batched inputs, position_ids generation, and past key values masking
  • Refactors test files to use batched inputs by default and improve test coverage for various model architectures
  • Simplifies CI workflows and removes deprecated model patchers

Reviewed Changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
tests/onnxruntime/testing_utils.py Adds new model variants for falcon-alibi-True and gpt_bigcode-multi_query-False testing
tests/onnxruntime/test_diffusion.py Replaces external image loading with random image generation and standardizes tolerance values
tests/onnxruntime/test_decoder.py Major refactor to use batched inputs, improved tokenizer handling, and better past key values comparison
tests/exporters/onnx/test_export.py Updates custom MPT config to use correct sequence dimension naming
optimum/onnxruntime/modeling_decoder.py Comprehensive fixes for position_ids creation, past key values handling, and model-specific logic
optimum/exporters/onnx/utils.py Adds missing model types to position_ids requirements list
optimum/exporters/onnx/model_patcher.py Removes deprecated model patchers and adds ONNX-compatible implementations
optimum/exporters/onnx/model_configs.py Simplifies model configurations and fixes sequence dimension naming
optimum/exporters/onnx/config.py Fixes attention mask dimension ordering for past key values
optimum/exporters/onnx/base.py Updates sequence dimension naming for past key values
.github/workflows/test_onnxruntime_slow.yml Simplifies CI workflow by removing matrix variations
.github/workflows/test_onnxruntime.yml Adds transformers version matrix testing
Comments suppressed due to low confidence (3)

optimum/exporters/onnx/model_patcher.py:374

  • [nitpick] The function name 'onnx_compatible_triu' doesn't follow the existing naming pattern. Consider renaming to 'patched_triu' to match the pattern used by other patched functions like 'patched_forward'.
def onnx_compatible_triu(input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:

optimum/exporters/onnx/model_patcher.py:366

  • [nitpick] The function name 'onnx_compatible_tril' doesn't follow the existing naming pattern. Consider renaming to 'patched_tril' to match the pattern used by other patched functions like 'patched_forward'.
def onnx_compatible_tril(input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:

optimum/exporters/onnx/model_patcher.py:406

  • [nitpick] The function name 'noop_bfloat16_casting' could be more descriptive. Consider renaming to 'identity_bfloat16_casting' or 'passthrough_bfloat16_casting' to better indicate that it returns the tensor unchanged.
def noop_bfloat16_casting(self):

Comment thread tests/onnxruntime/test_decoder.py Outdated
Copy link
Copy Markdown
Collaborator

@echarlaix echarlaix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks a lot

@echarlaix echarlaix merged commit 94e0d02 into main Aug 1, 2025
33 checks passed
@echarlaix echarlaix deleted the batched-inference branch August 1, 2025 13:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants