🐛 Describe the bug
When running Gemma 3 models (specifically gemma-3-270m and gemma-3-270m-it) exported to the ExecuTorch format with optimum-cli, the models start to produce ill-formed or garbage output after reaching their 512-token sliding-window limit. The tests were conducted through the Optimum-ExecuTorch interface, and the models were fed token by token, using the KV cache.
Steps to reproduce
The tested model (gemma-3-270m-it) was exported with the following command:
```bash
optimum-cli export executorch \
  --model "google/gemma-3-270m-it" \
  --task "text-generation" \
  --recipe "xnnpack" \
  --use_custom_sdpa \
  --use_custom_kv_cache \
  --qlinear 8da4w \
  --qembedding 8w \
  --output_dir="gemma3_270" \
  --max_seq_len 2048
```
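As a side check, the sliding-window size can be read from the upstream Hugging Face config. This is only a minimal sketch (it assumes transformers is installed; the getattr fallback covers both the text-only and the nested multimodal config layouts):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-3-270m-it")
text_cfg = getattr(cfg, "text_config", cfg)  # nested config for multimodal variants
print(text_cfg.sliding_window)  # the 512-token window the degradation coincides with
```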
And here is the test code itself (the chat template is taken from the Gemma repository):
```python
from pathlib import Path

from jinja2 import Template
from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import GemmaTokenizer

tokenizer = GemmaTokenizer.from_pretrained("google/gemma-3-270m-it")
model = ExecuTorchModelForCausalLM.from_pretrained(Path("./gemma3_270"))

with open("./gemma3_270/chat_template.jinja") as f:
    template = Template(f.read())

messages = [
    {"role": "user", "content": "What is love? Answer in at least 600 words."},
]
prompt = template.render(messages=messages, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors="pt")
out = model.generate(inputs["input_ids"][0].tolist())
print(tokenizer.decode(out, skip_special_tokens=False))
```
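To make the failure point easy to see, the decoded output can be split at the 512-token mark. This is a sketch appended to the reproduction script above; it assumes `out` holds the full processed sequence (prompt plus generated tokens), which matches how the critical section is marked in the attached log:

```python
boundary = 512  # gemma-3-270m sliding-window size
print("--- tokens up to the 512th ---")
print(tokenizer.decode(out[:boundary], skip_special_tokens=False))
print("--- tokens after the 512th (ill-formed region) ---")
print(tokenizer.decode(out[boundary:], skip_special_tokens=False))
```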
Actual behavior
The output looks reasonable until the model has processed at least 512 tokens, which corresponds to the gemma-3-270m sliding-window size. The model then starts producing looped phrases, which eventually turn into garbage. Generation continues until the maximum sequence length of 2048 is reached, without an EOS token ever being produced.
A sample log of the output is available here. For better readability, I added the full input prompt (after applying the chat template) and marked the critical section that starts after the 512th processed token.
What I tried
- Different combinations of export parameters and flags (within the optimum-cli export command); none of them resulted in correct model behavior. (A possible next check, an eager-mode comparison, is sketched below.)
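For completeness, a plain eager-mode run of the same prompt would help isolate whether the degradation comes from the export path or from the base model itself. This is only a sketch of such a comparison, not something from the report above; it assumes the installed transformers version supports Gemma 3 natively:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m-it")

messages = [
    {"role": "user", "content": "What is love? Answer in at least 600 words."},
]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)

# Greedy decoding well past the 512-token window, for comparison with the
# ExecuTorch output above.
with torch.no_grad():
    out = model.generate(input_ids, max_new_tokens=1536, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=False))
```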
Versions
Collecting environment information...
PyTorch version: 2.10.0.dev20250916
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.7.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.11 (main, Jun 3 2025, 15:41:47) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.7.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] executorch==1.0.0.dev20250916
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.3.4
[pip3] optimum-executorch==0.2.0.dev0
[pip3] pytorch_tokenizers==0.1.0
[pip3] torch==2.10.0.dev20250916
[pip3] torchao==0.14.0.dev20250916+cpu
[pip3] torchaudio==2.8.0.dev20250916
[pip3] torchvision==0.25.0.dev20250916
[conda] Could not collect