System Info
- transformers version: 4.51.2
- Platform: Linux-6.8.0-1021-aws-x86_64-with-glibc2.35
- Python version: 3.12.8
- Huggingface_hub version: 0.30.2
- Safetensors version: 0.5.3
- Accelerate version: 1.5.2
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: No
- GPU type: NVIDIA L40S
Who can help?
Reproduction
For the Cohere2 and Gemma2 models, if the last_cache_position argument is not supplied at runtime to the model's forward method, it is derived either from the 2D attention mask (when one is provided) or from the cache_position tensor. From the source code:
if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
        )

However, by design, attention_mask.shape[-1] is never equal to cache_position[-1].item(): it is always equal to cache_position[-1].item() + 1.
This can be asserted using the following script:
import torch
def create_cache_position(attention_mask_2d: torch.LongTensor, is_prefill: bool) -> torch.LongTensor:
    # From transformers.generation.utils.GenerationMixin._get_initial_cache_position & _update_model_kwargs_for_generation
    cache_position = torch.ones_like(attention_mask_2d[0, :], dtype=torch.int64).cumsum(0) - 1
    if is_prefill:
        return cache_position
    else:
        return cache_position[-1:]

def update_2d_attention_mask(attention_mask_2d: torch.LongTensor, padding_side: str) -> torch.LongTensor:
    # From transformers.generation.utils.GenerationMixin._update_model_kwargs_for_generation
    batch_size, _ = attention_mask_2d.shape
    if padding_side == "left":
        attention_mask_2d = torch.cat([attention_mask_2d, attention_mask_2d.new_ones((batch_size, 1))], dim=1)
    else:
        attention_mask_2d = torch.cat([attention_mask_2d.new_ones((batch_size, 1)), attention_mask_2d], dim=1)
    return attention_mask_2d

# PREFILL
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32)
cache_position = create_cache_position(attention_mask_2d, is_prefill=True)
assert attention_mask_2d.shape[-1] == cache_position[-1].item() + 1

# TOKEN GENERATION
attention_mask_2d = update_2d_attention_mask(attention_mask_2d, padding_side="left")
cache_position = create_cache_position(attention_mask_2d, is_prefill=False)
assert attention_mask_2d.shape[-1] == cache_position[-1].item() + 1

Expected behavior
Defining last_cache_position = attention_mask.shape[-1] produces the expected behavior (and this is the behavior we get when using the generate API, at least with the Cohere2 model), so we just need to make last_cache_position consistent as follows:
if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() + 1
        )
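As a quick sanity check (hypothetical values, in the spirit of the script above), both expressions used in this fallback now refer to the same quantity, namely the total number of positions processed so far:

import torch

# Hypothetical decoding step: 5 prompt tokens already in the cache, 1 new token being generated
attention_mask = torch.ones((1, 6), dtype=torch.long)    # 2D mask covering all 6 positions
cache_position = torch.tensor([5])                       # cache position of the new token

from_2d_mask = attention_mask.shape[-1]                  # value used when a 2D mask is supplied
from_cache_position = cache_position[-1].item() + 1      # value used otherwise
assert from_2d_mask == from_cache_position == 6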
However, multiple docstrings and comments in the code describe last_cache_position as being identical to cache_position[-1]. If we instead choose to define last_cache_position = cache_position[-1], then the code above must be adjusted as follows:
if last_cache_position is None:
    last_cache_position = 0
    if attention_mask is not None:
        last_cache_position = (
            attention_mask.shape[-1] - 1 if attention_mask.dim() == 2 else cache_position[-1].item()
        )
On top of that, the attention mask subsetting operation in the decoder layer's forward method (which is the only place where last_cache_position is used) must be adjusted to account for this change:
effective_seq_len = max(cache_position.shape[0], self.sliding_window)
# ...
offset = (last_cache_position + 1) - effective_seq_len
offset = max(0, offset)
attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
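For illustration, here is a small numeric check (hypothetical numbers: sliding window of 4, decoding the 10th token) showing that the two proposals select the same mask columns once the slicing is adjusted accordingly:

sliding_window = 4
cache_position = [9]                      # decoding the 10th token
attention_mask_len = 10                   # 2D attention mask covers 10 positions
effective_seq_len = max(len(cache_position), sliding_window)

# Option 1: last_cache_position = attention_mask.shape[-1], slicing left unchanged
offset_1 = max(0, attention_mask_len - effective_seq_len)

# Option 2: last_cache_position = cache_position[-1], slicing shifted by +1
offset_2 = max(0, (cache_position[-1] + 1) - effective_seq_len)

assert offset_1 == offset_2 == 6          # both slice the mask columns [6:10]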