Merged
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
@@ -656,6 +656,7 @@ def prepare_inputs_for_generation(
# If it's not defined, it means the model uses the new general mask API
if causal_mask_creation_function is None: # can't be found
token_type_ids = getattr(model_input, "token_type_ids", None)
position_ids = getattr(model_input, position_ids_key, None)
# Some models may overwrite the general one
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
attention_mask = causal_mask_creation_function(
@@ -665,6 +666,7 @@
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
token_type_ids=token_type_ids,
)
else:
104 changes: 95 additions & 9 deletions src/transformers/masking_utils.py
@@ -112,6 +112,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:


def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
"""
This returns the mask function corresponding to a 2D padding mask.
"""

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
# we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
@@ -121,6 +125,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return inner_mask


def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
"""
This returns the mask function corresponding to a 2D packed sequence mask.
"""

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]

return inner_mask
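
As a quick illustration (a sketch, not part of this diff), the returned mask function simply compares sequence ids, so attention is only allowed inside a packed sequence; the tensor below is a made-up example using the `packed_sequence_mask_function` defined above:

import torch

# One batch row packing two sequences of lengths 2 and 3 -> sequence ids [0, 0, 1, 1, 1]
toy_packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1]])
mask_fn = packed_sequence_mask_function(toy_packed_sequence_mask)

print(mask_fn(0, 0, 1, 0))  # tensor(True)  -> query 1 and key 0 are in the same packed sequence
print(mask_fn(0, 0, 2, 1))  # tensor(False) -> attention across the packing boundary is blocked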


def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
@@ -592,12 +607,40 @@ class AttentionMaskInterface(GeneralInterface):
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()


def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
"""
Find the index of the sequence to which each query token belongs when using the packed tensor format
(i.e. several sequences packed along the same batch dimension).

Args:
position_ids (`torch.Tensor`):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.

Returns:
A 2D tensor in which tokens sharing the same integer value belong to the same sequence. For example, if we
pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
"""
# What separates different sequences is when 2 consecutive position_ids are separated by more than 1. So
# taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
# gives exactly the sequence indices
# Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
# cannot be part of the end of the first batch dim and the start of the 2nd one for example
first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
packed_sequence_mask = (position_diff != 1).cumsum(-1)

# Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
# but it causes issues with export
return packed_sequence_mask
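
To make the diff-and-cumsum trick concrete, here is a small worked example (a sketch, not part of this diff) reproducing the docstring case of three packed sequences of 2, 3 and 1 tokens:

import torch

position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])           # positions restart at each new sequence
first_dummy_value = position_ids[:, :1] - 1                  # [[-1]], so the very first diff is 1
position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
# position_diff is [[1, 1, -1, 1, 1, -2]]; every entry != 1 marks the start of a new sequence
packed_sequence_mask = (position_diff != 1).cumsum(-1)
print(packed_sequence_mask)                                  # tensor([[0, 0, 1, 1, 1, 2]])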


def _preprocess_mask_arguments(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: Optional[Union[torch.Tensor, BlockMask]],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
"""
@@ -617,6 +660,8 @@ def _preprocess_mask_arguments(
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`torch.Tensor`, optional):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
layer_idx (`int`, optional):
If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -626,22 +671,25 @@
Whether we should early exit mask creation, and return the mask as-is.
attention_mask (`torch.Tensor` or `BlockMask` or `None`):
The attention mask to either return immediately, or to use in downstream mask creation.
packed_sequence_mask (`torch.Tensor`, optional):
In case we detected the packed sequence format, this is a tensor in which tokens sharing the same integer
value belong to the same sequence.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`):
An offset to indicate at which first position the key and values states will refer to.
"""
# If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
return True, attention_mask, None, None
return True, attention_mask, None, None, None

# For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
# Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
# full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
# with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
# according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
return True, None, None, None
return True, None, None, None, None

# Move the mask to correct device, and potentially switch dtype for efficiency
if attention_mask is not None and attention_mask.ndim == 2:
@@ -654,7 +702,17 @@
else:
kv_length, kv_offset = input_embeds.shape[1], 0

return False, attention_mask, kv_length, kv_offset
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
# and we don't have past_key_values, i.e. generally a training setup)
packed_sequence_mask = None
if position_ids is not None and attention_mask is None and past_key_values is None:
batch_size = input_embeds.shape[0]
# The position ids are sometimes just unsqueezed, without being expanded
if batch_size != position_ids.shape[0]:
position_ids = position_ids.expand(batch_size, -1)
packed_sequence_mask = find_packed_sequence_indices(position_ids)

return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
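
For illustration only (hypothetical tensors, not part of this diff): the detection above is skipped as soon as a 2D attention mask or a cache is present, and a broadcast (1, seq_len) position_ids is expanded to the batch size before the check:

import torch

batch_size = 2
position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])        # shape (1, 6): unsqueezed but not expanded
attention_mask, past_key_values = None, None              # typical packed-training setup

packed_sequence_mask = None
if position_ids is not None and attention_mask is None and past_key_values is None:
    if batch_size != position_ids.shape[0]:
        position_ids = position_ids.expand(batch_size, -1)    # view with shape (2, 6), no copy
    packed_sequence_mask = find_packed_sequence_indices(position_ids)
# packed_sequence_mask is tensor([[0, 0, 1, 1, 1, 2], [0, 0, 1, 1, 1, 2]])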


def create_causal_mask(
@@ -663,6 +721,7 @@
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -684,6 +743,8 @@
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`torch.Tensor`, optional):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -697,8 +758,8 @@
else:
layer_idx = 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
)
if early_exit:
return attention_mask
@@ -711,6 +772,11 @@
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

# If we detected the packed sequence format
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
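
As a rough sketch of what this composition does (not part of this diff, and assuming `causal_mask_function` and `and_masks` from this module keep their current signatures), the combined function only allows positions that are both causal and inside the same packed sequence:

import torch

toy_packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])
combined = and_masks(causal_mask_function, packed_sequence_mask_function(toy_packed_sequence_mask))

b, h = torch.tensor(0), torch.tensor(0)
# q_idx=3, kv_idx=2: causal and same packed sequence -> allowed
print(combined(b, h, torch.tensor(3), torch.tensor(2)))   # tensor(True)
# q_idx=3, kv_idx=1: causal, but the key belongs to the previous packed sequence -> masked
print(combined(b, h, torch.tensor(3), torch.tensor(1)))   # tensor(False)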

# Allow slight deviations from causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
@@ -744,6 +810,7 @@ def create_sliding_window_causal_mask(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -766,6 +833,8 @@
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`torch.Tensor`, optional):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -779,8 +848,8 @@
else:
layer_idx = 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
)
if early_exit:
return attention_mask
@@ -797,6 +866,11 @@
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

# If we detected the packed sequence format
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False

# Allow slight deviations from sliding causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
@@ -831,6 +905,7 @@ def create_chunked_causal_mask(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -853,6 +928,8 @@
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`torch.Tensor`, optional):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -866,8 +943,8 @@
else:
layer_idx = 0

early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
)
if early_exit:
return attention_mask
@@ -891,6 +968,11 @@
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

# If we detected the packed sequence format
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False

# Allow slight deviations from chunked causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
@@ -932,6 +1014,7 @@ def create_masks_for_generate(
attention_mask: Optional[torch.Tensor],
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
position_ids: Optional[torch.Tensor],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
**kwargs,
@@ -953,6 +1036,8 @@
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`torch.Tensor`, optional):
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the other mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -969,6 +1054,7 @@
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": position_ids,
"or_mask_function": or_mask_function,
"and_mask_function": and_mask_function,
}
1 change: 1 addition & 0 deletions src/transformers/models/arcee/modeling_arcee.py
@@ -423,6 +423,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -806,6 +806,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/bitnet/modeling_bitnet.py
@@ -420,6 +420,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -457,6 +457,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -434,6 +434,7 @@ def forward(
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
1 change: 1 addition & 0 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -455,6 +455,7 @@ def forward(
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {
2 changes: 2 additions & 0 deletions src/transformers/models/csm/modeling_csm.py
@@ -500,6 +500,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
@@ -811,6 +812,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/csm/modular_csm.py
@@ -238,6 +238,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
@@ -610,6 +610,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/dia/modeling_dia.py
@@ -629,6 +629,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
1 change: 1 addition & 0 deletions src/transformers/models/dia/modular_dia.py
@@ -455,6 +455,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
1 change: 1 addition & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -697,6 +697,7 @@ def forward(
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
1 change: 1 addition & 0 deletions src/transformers/models/dots1/modeling_dots1.py
@@ -532,6 +532,7 @@ def forward(
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
# Create the masks
causal_mask_mapping = {