
last_cache_position definition issue in hybrid SWA models #37706

@plienhar

Description


System Info

- transformers version: 4.51.2
- Platform: Linux-6.8.0-1021-aws-x86_64-with-glibc2.35
- Python version: 3.12.8
- Huggingface_hub version: 0.30.2
- Safetensors version: 0.5.3
- Accelerate version: 1.5.2
- Accelerate config:    not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: No
- GPU type: NVIDIA L40S

Who can help?

@gante

Reproduction

For the Cohere2 and Gemma2 models, if the last_cache_position argument is not supplied at runtime to their Model.forward method, it is derived either from the 2D attention mask, if one is supplied, or from the cache_position tensor. From the source code:

if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
        )

However, by design, attention_mask.shape[-1] is never equal to cache_position[-1].item(); it is always equal to cache_position[-1].item() + 1.

This can be verified with the following script:

import torch

def create_cache_position(attention_mask_2d: torch.LongTensor, is_prefill: bool) -> torch.LongTensor:
    # From transformers.generation.utils.GenerationMixin._get_initial_cache_position & _update_model_kwargs_for_generation
    cache_position = torch.ones_like(attention_mask_2d[0, :], dtype=torch.int64).cumsum(0) - 1
    if is_prefill:
        return cache_position
    else:
        return cache_position[-1:]
    

def update_2d_attention_mask(attention_mask_2d: torch.LongTensor, padding_side: str) -> torch.LongTensor:
    # From transformers.generation.utils.GenerationMixin._update_model_kwargs_for_generation
    batch_size, _ = attention_mask_2d.shape
    if padding_side == "left":
        attention_mask_2d = torch.cat([attention_mask_2d, attention_mask_2d.new_ones((batch_size, 1))], dim=1)
    else:
        attention_mask_2d = torch.cat([attention_mask_2d.new_ones((batch_size, 1)), attention_mask_2d], dim=1)
    return attention_mask_2d

# PREFILL
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32)
cache_position = create_cache_position(attention_mask_2d, is_prefill=True)

assert attention_mask_2d.shape[-1] == cache_position[-1].item() + 1

# TOKEN GENERATION
attention_mask_2d = update_2d_attention_mask(attention_mask_2d, padding_side="left")
cache_position = create_cache_position(attention_mask_2d, is_prefill=False)

assert attention_mask_2d.shape[-1] == cache_position[-1].item() + 1
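
The same off-by-one can be seen directly on the two branches of the current fallback; a minimal sketch, reusing the prefill tensors from the script above:

import torch

# Reusing the prefill tensors from the script above.
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32)
cache_position = torch.ones_like(attention_mask_2d[0, :], dtype=torch.int64).cumsum(0) - 1

branch_with_2d_mask = attention_mask_2d.shape[-1]       # 5, taken when a 2D mask is supplied
branch_without_mask = cache_position[-1].item()         # 4, taken otherwise
assert branch_with_2d_mask == branch_without_mask + 1   # the two branches always disagree by one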

Expected behavior

Defining last_cache_position = attention_mask.shape[-1] produces the expected behavior (and this is the behavior we get when using the generate API, with the Cohere2 model at least), so we just need to make last_cache_position consistent, as follows:

if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + 1
        )
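
As a quick sanity check, wrapping this proposed fallback in a small (hypothetical) helper shows that both branches now agree, whether or not a 2D mask is supplied:

import torch

def fallback_last_cache_position(attention_mask, cache_position):
    # Hypothetical helper mirroring the proposed fallback, for illustration only.
    if attention_mask is not None and attention_mask.dim() == 2:
        return attention_mask.shape[-1]
    return cache_position[-1].item() + 1

attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32)
cache_position = torch.arange(5)

# Both code paths produce the same value for the same prefill state.
assert fallback_last_cache_position(attention_mask_2d, cache_position) == 5
assert fallback_last_cache_position(None, cache_position) == 5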

However, multiple docstrings and comments in the code describe last_cache_position as being identical to cache_position[-1]. If we instead choose to define last_cache_position = cache_position[-1], then the code above must be adjusted as follows:

if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] - 1 if attention_mask.dim() == 2 else cache_position[-1].item()
        )

On top of that, the attention mask subsetting operation in the decoder layer's forward method (which is the only place where last_cache_position is used) must be adjusted to account for this change:

effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# ...
offset = (last_cache_position + 1) - effective_seq_len
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
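
For a concrete illustration of the adjusted slicing under this convention, here is a minimal sketch with hypothetical numbers: the 10th token (positions 0 through 9) is being decoded with sliding_window = 4, so the slice keeps only the last 4 key positions:

import torch

# Hypothetical decode-step numbers, with last_cache_position = cache_position[-1].
sliding_window = 4
cache_position = torch.tensor([9])                  # single new position during decoding
last_cache_position = cache_position[-1].item()     # 9

effective_seq_len = max(cache_position.shape[0], sliding_window)   # max(1, 4) = 4
offset = max(0, (last_cache_position + 1) - effective_seq_len)     # (9 + 1) - 4 = 6

# Dummy 4D mask of shape (batch, 1, query_len, key_len) covering 10 key positions.
attention_mask = torch.zeros(1, 1, 1, 10)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
assert attention_mask.shape[-1] == effective_seq_len   # keys 6..9, the last `sliding_window` positions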
