
Commit d3e4c2a

Cyrilvallez authored and rjgleaton committed
Add packed tensor format support for flex/sdpa/eager through the mask! (huggingface#39194)
* Add the necessary logic to mask_utils
* add it everywhere
* Update masking_utils.py
* style
* Update masking_utils.py
* Update modeling_mimi.py
* Update masking_utils.py
* add support for more than batch size 1
* Update masking_utils.py
* add test
* style
* Update test_masking_utils.py
* Update masking_utils.py
* add require_token
* fix tests
* fix
1 parent 7f71746 commit d3e4c2a

File tree

65 files changed: +303 -9 lines changed

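What the change enables, in short: when several sequences are packed along one batch row and only `position_ids` mark the boundaries (positions restarting at 0 at each new sequence), the mask utilities now detect the packing and build a block-diagonal causal mask for the eager, sdpa and flex attention paths. A minimal sketch of such a call, assuming a causal LM checkpoint; the model name and token ids below are placeholders, not part of this commit:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="sdpa")

# Two sequences of lengths 2 and 4 packed into a single batch row.
input_ids = torch.tensor([[101, 102, 201, 202, 203, 204]])
position_ids = torch.tensor([[0, 1, 0, 1, 2, 3]])  # positions restart at each packed sequence

# No attention_mask and no cache: the mask utilities detect the packing from position_ids
# and prevent the packed sequences from attending to each other.
outputs = model(input_ids=input_ids, position_ids=position_ids)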

src/transformers/generation/utils.py

Lines changed: 2 additions & 0 deletions
@@ -656,6 +656,7 @@ def prepare_inputs_for_generation(
             # If it's not defined, it means the model uses the new general mask API
             if causal_mask_creation_function is None:  # can't be found
                 token_type_ids = getattr(model_input, "token_type_ids", None)
+                position_ids = getattr(model_input, position_ids_key, None)
                 # Some models may overwrite the general one
                 causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
                 attention_mask = causal_mask_creation_function(
@@ -665,6 +666,7 @@ def prepare_inputs_for_generation(
                     attention_mask=attention_mask,
                     cache_position=cache_position,
                     past_key_values=past_key_values,
+                    position_ids=position_ids,
                     token_type_ids=token_type_ids,
                 )
             else:

src/transformers/masking_utils.py

Lines changed: 95 additions & 9 deletions
@@ -112,6 +112,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:
 
 
 def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+    """
+    This return the mask_function function corresponding to a 2D padding mask.
+    """
+
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
         # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
@@ -121,6 +125,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     return inner_mask
 
 
+def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
+    """
+    This return the mask_function function corresponding to a 2D packed sequence mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+    return inner_mask
+
+
 def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
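For intuition, the new `inner_mask` above answers the question one (batch, head, query, key) coordinate at a time; the dense mask it induces, once AND-ed with the causal rule as done later in this commit, can be sketched with plain broadcasting. The packed ids below are illustrative values, not taken from the diff:

import torch

# Sequences of lengths 2, 3 and 1 packed along one batch row (illustrative values).
packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

# inner_mask(b, h, q, kv) is True exactly when q and kv carry the same sequence index...
same_sequence = packed_sequence_mask[:, :, None] == packed_sequence_mask[:, None, :]
# ...and combining it with the causal rule gives a block-diagonal causal mask.
causal = torch.arange(6)[:, None] >= torch.arange(6)[None, :]
print((same_sequence & causal).int())
# tensor([[[1, 0, 0, 0, 0, 0],
#          [1, 1, 0, 0, 0, 0],
#          [0, 0, 1, 0, 0, 0],
#          [0, 0, 1, 1, 0, 0],
#          [0, 0, 1, 1, 1, 0],
#          [0, 0, 0, 0, 0, 1]]])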
@@ -592,12 +607,40 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
 
 
+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
+    """
+    Find the indices of the sequence to which each new query token in the sequence belongs when using packed
+    tensor format (i.e. several sequences packed in the same batch dimension).
+
+    Args:
+        position_ids (`torch.Tensor`)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+    Returns:
+        A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
+        pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+    """
+    # What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So
+    # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+    # gives exactly the sequence indices
+    # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
+    # cannot be part of the end of the first batch dim and the start of the 2nd one for example
+    first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
+    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
+    # but it causes issues with export
+    return packed_sequence_mask
+
+
 def _preprocess_mask_arguments(
     config: PretrainedConfig,
     input_embeds: torch.Tensor,
     attention_mask: Optional[Union[torch.Tensor, BlockMask]],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     layer_idx: Optional[int],
 ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
     """
@@ -617,6 +660,8 @@ def _preprocess_mask_arguments(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         layer_idx (`int`, optional):
             If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
             length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -626,22 +671,25 @@
             Whether we should early exit mask creation, and return the mask as-is.
         attention_mask (`torch.Tensor` or `BlockMask` or `None`):
             The attention mask to either return immediately, or to use in downstream mask creation.
+        packed_sequence_mask (`torch.Tensor`, optional):
+            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
+            the tokens belong to the same sequence.
         kv_length (`int`):
             The size that the key and value states will have during the attention computation.
         kv_offset (`int`):
             An offset to indicate at which first position the key and values states will refer to.
     """
     # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
     if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
-        return True, attention_mask, None, None
+        return True, attention_mask, None, None, None
 
     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
     # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
     # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
     # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
     # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
     if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
-        return True, None, None, None
+        return True, None, None, None, None
 
     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
@@ -654,7 +702,17 @@ def _preprocess_mask_arguments(
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0
 
-    return False, attention_mask, kv_length, kv_offset
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
+    packed_sequence_mask = None
+    if position_ids is not None and attention_mask is None and past_key_values is None:
+        batch_size = input_embeds.shape[0]
+        # The position ids are sometimes just unsqueezed, without being expanded
+        if batch_size != position_ids.shape[0]:
+            position_ids = position_ids.expand(batch_size, -1)
+        packed_sequence_mask = find_packed_sequence_indices(position_ids)
+
+    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
 
 
 def create_causal_mask(
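The guard added above only triggers in a training-style call (no 2D attention mask and no cache), and it tolerates `position_ids` that were merely unsqueezed to shape (1, query_length) rather than expanded to the full batch. A small illustration of that expansion step, with illustrative values:

import torch

position_ids = torch.arange(6).unsqueeze(0)  # shape (1, 6), shared by every row of the batch
batch_size = 4
if batch_size != position_ids.shape[0]:
    position_ids = position_ids.expand(batch_size, -1)  # view of shape (4, 6), no data copy
print(position_ids.shape)  # torch.Size([4, 6])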
@@ -663,6 +721,7 @@ def create_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -684,6 +743,8 @@ def create_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -697,8 +758,8 @@
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -711,6 +772,11 @@ def create_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
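In `create_causal_mask` above, the packed-sequence callable is AND-composed with the causal mask function, and `allow_is_causal_skip` is turned off so the combined mask is actually materialized. A rough standalone sketch of that intersection, using a local stand-in for `and_masks` rather than the library's own implementation:

import torch
from typing import Callable

def and_masks(*mask_functions: Callable) -> Callable:
    # Local stand-in: a (q, kv) pair is allowed only if every mask function allows it.
    def combined(batch_idx, head_idx, q_idx, kv_idx):
        result = mask_functions[0](batch_idx, head_idx, q_idx, kv_idx)
        for fn in mask_functions[1:]:
            result = result & fn(batch_idx, head_idx, q_idx, kv_idx)
        return result
    return combined

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
    return inner_mask

packed = torch.tensor([[0, 0, 1, 1, 1, 2]])  # sequences of 2, 3 and 1 tokens in one row
mask_fn = and_masks(causal_mask_function, packed_sequence_mask_function(packed))
# Query 3 may attend to key 2 (same sequence, earlier position) but not to key 1 (other sequence).
print(mask_fn(0, 0, 3, 2), mask_fn(0, 0, 3, 1))  # True then False (as tensors)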
@@ -744,6 +810,7 @@ def create_sliding_window_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -766,6 +833,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -779,8 +848,8 @@
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -797,6 +866,11 @@
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -831,6 +905,7 @@ def create_chunked_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -853,6 +928,8 @@ def create_chunked_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -866,8 +943,8 @@
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -891,6 +968,11 @@
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from chunked causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -932,6 +1014,7 @@ def create_masks_for_generate(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
     **kwargs,
@@ -953,6 +1036,8 @@ def create_masks_for_generate(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the other mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -969,6 +1054,7 @@
         "attention_mask": attention_mask,
         "cache_position": cache_position,
         "past_key_values": past_key_values,
+        "position_ids": position_ids,
         "or_mask_function": or_mask_function,
         "and_mask_function": and_mask_function,
     }

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 1 addition & 0 deletions
@@ -423,6 +423,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 0 deletions
@@ -806,6 +806,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 1 addition & 0 deletions
@@ -420,6 +420,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 0 deletions
@@ -457,6 +457,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 1 addition & 0 deletions
@@ -434,6 +434,7 @@ def forward(
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
+                "position_ids": position_ids,
             }
             # Create the masks
             causal_mask_mapping = {

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 1 addition & 0 deletions
@@ -455,6 +455,7 @@ def forward(
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
+                "position_ids": position_ids,
             }
             # Create the masks
             causal_mask_mapping = {

src/transformers/models/csm/modeling_csm.py

Lines changed: 2 additions & 0 deletions
@@ -500,6 +500,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds
@@ -811,6 +812,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds

src/transformers/models/csm/modular_csm.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )
 
         hidden_states = inputs_embeds
