[Llama4] Chunked Attention #37351

Description

@vasqu

System Info

  • transformers version: 4.52.0.dev0 (around commit 8bbcdf5)
  • Platform: Linux-6.8.0-111057-tuxedo-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.4.3
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Tensorflow version (GPU?): 2.15.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 3080 Ti Laptop GPU

Who can help?

@ArthurZucker @winglian (fyi)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Because I'm GPU poor, I modified Llama4 to have only one layer and a smaller hidden size.

Rough script:

import torch

from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM


config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct").get_text_config()

# modify config for debugging
config._attn_implementation = "eager"  # or "flex_attention", possibly also "sdpa"
config.hidden_size = 128
config.num_hidden_layers = 1
config.attention_chunk_size = 3  # causes issues in eager; leaving the default instead causes issues in flex attention

# some dummy data
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
tokenizer.padding_side = "left"  # irrelevant tbh
input_text = ["What are we having for dinner?", "How are you?"]
input_ids = tokenizer(input_text, padding=True, return_tensors="pt").to("cuda")

# init module, half precision to save on vram 
test_module = Llama4ForCausalLM(config).to("cuda", torch.bfloat16)

# simple forward pass
test_module.forward(**input_ids)

This can cause various issues, e.g.

  • In eager: RuntimeError: The size of tensor a (8) must match the size of tensor b (2) at non-singleton dimension 0
  • In flex: ValueError: block_mask was created for block_mask.shape=(2, 1, 8, tensor(8192, device='cuda:0')) but got q_len=8 and kv_len=8. (...); see the related report in more fixes for post-training llama4 #37329 (comment)
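
For reference, below is a minimal sketch of what I understand the chunked causal mask should look like for the repro shapes (chunk size 3, q_len 8). The chunked_causal_mask helper and the exact masking convention (causal attention restricted to positions in the same chunk) are my own assumptions for illustration, not the actual transformers implementation:

import torch

def chunked_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # True where query position q may attend to key position k:
    # k must be causal (k <= q) and fall inside the same chunk as q.
    pos = torch.arange(seq_len)
    same_chunk = (pos[:, None] // chunk_size) == (pos[None, :] // chunk_size)
    causal = pos[None, :] <= pos[:, None]
    return same_chunk & causal

# chunk_size=3 as in the repro config, seq_len=8 as in the padded batch above
print(chunked_causal_mask(seq_len=8, chunk_size=3).int())

This prints a block-diagonal lower-triangular pattern: two 3x3 causal blocks followed by a trailing 2x2 block.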

Expected behavior

Chunked attention doesn't seem to be handled correctly at the moment. A lot of code never enters this territory because the context has to be fairly long, i.e. longer than the chunk size, before chunking is exercised at all.
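
To make that last point concrete, here is a hypothetical check reusing the same toy mask helper as above (again my own sketch, not the transformers code): when the chunk size exceeds the sequence length, as with the 8192 visible in the block_mask shape above (presumably the default attention_chunk_size), the chunked mask degenerates into a plain causal mask, so the chunk handling is never actually exercised on short inputs:

import torch

def chunked_causal_mask(seq_len, chunk_size):
    pos = torch.arange(seq_len)
    return ((pos[:, None] // chunk_size) == (pos[None, :] // chunk_size)) & (pos[None, :] <= pos[:, None])

seq_len = 8  # length of the padded toy batch above
plain_causal = torch.ones(seq_len, seq_len).tril().bool()

# With a chunk size larger than the sequence, chunked attention is
# indistinguishable from plain causal attention, so any chunking bug stays hidden.
print(torch.equal(chunked_causal_mask(seq_len, 8192), plain_causal))  # True
print(torch.equal(chunked_causal_mask(seq_len, 3), plain_causal))     # False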
