Description:
When attempting to export the google/gemma-4-2b-it model to ONNX using optimum-cli, the export suffers from two critical flaws related to dummy input generation and memory management.
1. Mixed Head Dimensions in Gemma-4 (RoPE Corruption)
Gemma-4 uses a heterogeneous architecture in which head_dim varies layer by layer (e.g., alternating between 256 and 512).
Currently, optimum.utils.input_generators.GemmaDummyPastKeyValuesGenerator builds the dummy past_key_values with a single, uniform head_dim (usually defaulting to the global_head_dim of 512) for every layer.
Because these dummy shapes do not match the true dimensions of the 256-dim layers, the exported ONNX graph is forced into a 512-dim container for them. This breaks the Rotary Position Embedding (RoPE) computation in those smaller layers, because valid values are rotated against padded zeros. The resulting model outputs gibberish (e.g., "```( ** **").
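The RoPE failure described above can be shown with a toy, pure-Python example. It assumes the common Hugging Face style `rotate_half` layout, in which element i is paired with element i + dim/2, and a base of 10000. It uses 4 and 8 dimensions as stand-ins for 256 and 512. It illustrates the pairing mismatch only and is not a trace of the exported ONNX graph.

```python
import math
from typing import List


def rotate_half(x: List[float]) -> List[float]:
    # Split into halves [x1, x2] and return [-x2, x1] (HF-style RoPE pairing).
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    dim = len(x)
    inv_freq = [base ** (-2 * j / dim) for j in range(dim // 2)]
    # Frequencies are duplicated so element i and element i + dim/2 share one angle.
    angles = [position * f for f in inv_freq] * 2
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


true_head = [1.0, 2.0, 3.0, 4.0]     # toy stand-in for a 256-dim layer's head
padded_head = true_head + [0.0] * 4  # same values zero-padded to a toy "512" container

correct = apply_rope(true_head, position=3)
padded = apply_rope(padded_head, position=3)

# Correct case: x[0] is rotated with x[2], and x[1] with x[3].
# Padded case: every real value is paired with a padded zero, the angles are
# computed for the wrong dimension, and part of the signal leaks into the padding.
print([round(v, 3) for v in correct])
print([round(v, 3) for v in padded])
```

In the padded case, truncating back to the first four values does not recover the correct rotation.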
Suggested Fix: GemmaDummyPastKeyValuesGenerator.generate() needs to be updated to construct a layer-specific shape map, rather than reusing a single shape tuple for range(self.num_layers).
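A minimal sketch of the layer-specific shape map suggested above. The real Gemma-4 config attributes are not confirmed in this report. `HypotheticalGemma4Config`, its `layer_types` list, and the rule that `"full_attention"` layers use `global_head_dim` while all other layers use `head_dim` are assumptions for illustration. The usual (batch, num_kv_heads, past_seq_len, head_dim) cache layout is also assumed.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Shape = Tuple[int, int, int, int]


@dataclass
class HypotheticalGemma4Config:
    # Hypothetical stand-in for the real config; attribute names are assumptions.
    num_hidden_layers: int
    num_key_value_heads: int
    head_dim: int         # assumed dim of the smaller layers (e.g. 256)
    global_head_dim: int  # assumed dim of the larger layers (e.g. 512)
    layer_types: List[str] = field(default_factory=list)


def per_layer_kv_shapes(
    config: HypotheticalGemma4Config, batch_size: int, past_seq_len: int
) -> List[Tuple[Shape, Shape]]:
    """Return one (key_shape, value_shape) pair per layer instead of one shared shape."""
    if len(config.layer_types) != config.num_hidden_layers:
        raise ValueError("layer_types must have one entry per layer")
    shapes = []
    for layer_type in config.layer_types:
        # Assumption: global layers use global_head_dim, all others use head_dim.
        dim = config.global_head_dim if layer_type == "full_attention" else config.head_dim
        shape = (batch_size, config.num_key_value_heads, past_seq_len, dim)
        shapes.append((shape, shape))
    return shapes


cfg = HypotheticalGemma4Config(
    num_hidden_layers=4,
    num_key_value_heads=1,
    head_dim=256,
    global_head_dim=512,
    layer_types=["sliding_attention", "full_attention"] * 2,
)
for key_shape, _ in per_layer_kv_shapes(cfg, batch_size=2, past_seq_len=16):
    print(key_shape)  # last dim alternates between 256 and 512
```

Inside generate(), each pair would then be turned into its own dummy key/value tensors, rather than one shape tuple being repeated for range(self.num_layers).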
2. MemoryError on onnx.load(load_external_data=True)
For models larger than 2GB, the torch.onnx.export step completes successfully and writes model.onnx and its external weight chunks to disk. However, the process crashes immediately afterward at:
```python
# optimum/exporters/onnx/convert.py (export_pytorch)
onnx_model = onnx.load(str(output), load_external_data=True)
```
Because the original PyTorch model is still resident in system RAM, loading the 5GB ONNX graph plus its external data back into memory for validation immediately raises a MemoryError. This crashes the CLI and halts any post-processing.
Suggested Fix: Introduce a CLI flag to skip ONNX validation (e.g., --skip-validation), or explicitly delete/free the PyTorch model from memory before calling onnx.load().
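A minimal sketch of the two suggested fixes combined, assuming the exporter's caller holds the only reference to the PyTorch model. That reference is dropped first. Validation then passes a file path to onnx.checker.check_model, which accepts a path and requires one for models over 2GB, instead of calling onnx.load(..., load_external_data=True). The helper name `validate_onnx_by_path` is hypothetical. Its `skip_validation` parameter stands in for the proposed --skip-validation flag, which is not an existing optimum-cli option.

```python
import gc
from pathlib import Path
from typing import Union


def validate_onnx_by_path(onnx_path: Union[str, Path], skip_validation: bool = False) -> bool:
    """Hypothetical helper: validate an exported model without loading its weights into Python.

    The caller must drop its own references to the PyTorch model (e.g. `del model`)
    before calling. A `del` inside this function could not free the caller's object.
    Returns True if validation ran, False if it was skipped.
    """
    # Reclaim PyTorch tensors that are no longer referenced before validating.
    gc.collect()
    if skip_validation:
        return False
    import onnx  # imported lazily so the skip path does not require onnx

    # Passing a path rather than a ModelProto lets the checker handle >2GB models
    # without the external weight files being materialized in this process.
    onnx.checker.check_model(str(onnx_path))
    return True


# Sketch of the intended call order inside the exporter:
#   torch.onnx.export(model, ..., str(output))
#   del model                      # drop the last reference to the PyTorch weights
#   validate_onnx_by_path(output)  # or skip_validation=True for the proposed flag
```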
Steps to reproduce the behavior:
- Run the export:
```
optimum-cli export onnx --model google/gemma-4-2b-it --task text-generation-with-past models/gemma4-onnx
```
- Observe the MemoryError crash after tracing completes.
- If the model is loaded via ONNX Runtime using shape-elastic padding, observe the broken RoPE text generation caused by the uniform dummy dimensions.
Environment info:
- optimum version: Latest
- transformers version: 4.45.0+
- Platform: Windows ARM64 (Snapdragon X Elite / Samsung S25)
- Execution Provider: CPU / QNN Execution Provider