@@ -112,6 +112,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:
 
 
 def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask function corresponding to a 2D padding mask.
+    """
+
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
         # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
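For context on what this 2D padding mask looks like, here is a minimal standalone sketch. Note that the hunk only shows the new docstring; the `return padding_mask[batch_idx, kv_idx]` body below is an assumption about the function's behavior, not part of this diff.

import torch

# Hypothetical 2D padding mask for a batch of 2 sequences of length 4:
# True where a token is real, False where it is padding.
padding_mask = torch.tensor([[True, True, True, False],
                             [True, True, False, False]])

def padding_mask_function(padding_mask):
    # Assumed body (not shown in this hunk): a key/value position is visible iff it is not padding.
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        return padding_mask[batch_idx, kv_idx]
    return inner_mask

mask_fn = padding_mask_function(padding_mask)
print(bool(mask_fn(0, 0, 0, 3)))  # False: position 3 of batch entry 0 is padding
print(bool(mask_fn(0, 0, 0, 1)))  # True: position 1 is a real token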
@@ -121,6 +125,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     return inner_mask
 
 
+def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask function corresponding to a 2D packed sequence mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+    return inner_mask
+
+
 def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
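To make the new packed-sequence mask function concrete, here is a standalone sketch (not part of the diff) for three sequences of 2, 3 and 1 tokens packed into a single batch row. The function is copied locally so the snippet runs on its own; it only restricts attention to tokens of the same packed sequence.

import torch

# Three sequences of lengths 2, 3 and 1 packed along a single batch row,
# encoded as the sequence index of every token.
packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

def packed_sequence_mask_function(packed_sequence_mask):
    # Same logic as the function added above, copied here so the snippet runs standalone.
    def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
    return inner_mask

mask_fn = packed_sequence_mask_function(packed_sequence_mask)

# Materialize the (query, key) grid for batch 0, head 0: True only within the same sequence.
dense = torch.tensor([[bool(mask_fn(0, 0, q, k)) for k in range(6)] for q in range(6)])
print(dense.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])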
@@ -592,12 +607,40 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
 
 
+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
+    """
+    Find the index of the sequence to which each query token belongs when using packed tensor format
+    (i.e. several sequences packed along the same batch dimension).
+
+    Args:
+        position_ids (`torch.Tensor`):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+    Returns:
+        A 2D tensor where tokens sharing the same integer belong to the same sequence. For example, if we
+        pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+    """
+    # What separates different sequences is that 2 consecutive position_ids differ by more than 1. So
+    # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+    # gives exactly the sequence indices
+    # Note that we assume that a single sequence cannot span several batch dimensions, i.e. a single sequence
+    # cannot be part of the end of the first batch dim and the start of the 2nd one, for example
+    first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
+    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`,
+    # but it causes issues with export
+    return packed_sequence_mask
+
+
 def _preprocess_mask_arguments(
     config: PretrainedConfig,
     input_embeds: torch.Tensor,
     attention_mask: Optional[Union[torch.Tensor, BlockMask]],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     layer_idx: Optional[int],
 ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
     """
@@ -617,6 +660,8 @@ def _preprocess_mask_arguments(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         layer_idx (`int`, optional):
             If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
             length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -626,22 +671,25 @@ def _preprocess_mask_arguments(
             Whether we should early exit mask creation, and return the mask as-is.
         attention_mask (`torch.Tensor` or `BlockMask` or `None`):
             The attention mask to either return immediately, or to use in downstream mask creation.
+        packed_sequence_mask (`torch.Tensor`, optional):
+            If packed sequence format was detected, a tensor where tokens sharing the same integer belong to
+            the same sequence.
         kv_length (`int`):
             The size that the key and value states will have during the attention computation.
         kv_offset (`int`):
             An offset to indicate at which first position the key and values states will refer to.
     """
     # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
     if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
-        return True, attention_mask, None, None
+        return True, attention_mask, None, None, None
 
     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
     # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
     # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
     # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
     # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
     if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
-        return True, None, None, None
+        return True, None, None, None, None
 
     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
@@ -654,7 +702,17 @@ def _preprocess_mask_arguments(
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0
 
-    return False, attention_mask, kv_length, kv_offset
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
+    packed_sequence_mask = None
+    if position_ids is not None and attention_mask is None and past_key_values is None:
+        batch_size = input_embeds.shape[0]
+        # The position ids are sometimes just unsqueezed, without being expanded
+        if batch_size != position_ids.shape[0]:
+            position_ids = position_ids.expand(batch_size, -1)
+        packed_sequence_mask = find_packed_sequence_indices(position_ids)
+
+    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
 
 
 def create_causal_mask(
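The detection above only runs in the no-mask, no-cache case, i.e. typically training on packed data. Here is a standalone sketch of that gate, including the expansion of "unsqueezed" position ids; the input values are illustrative, and the helper is copied locally so the snippet runs on its own.

import torch

def find_packed_sequence_indices(position_ids):
    # Local copy of the helper added earlier in this diff, so the snippet runs standalone.
    first_dummy_value = position_ids[:, :1] - 1
    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
    return (position_diff != 1).cumsum(-1)

# Illustrative training-style inputs: no explicit attention mask, no cache, and
# position_ids given "unsqueezed" as shape (1, seq_len), as some models do.
input_embeds = torch.randn(2, 6, 16)
position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])
attention_mask = None
past_key_values = None

packed_sequence_mask = None
if position_ids is not None and attention_mask is None and past_key_values is None:
    batch_size = input_embeds.shape[0]
    # Mirror the expansion in the hunk above: broadcast to the real batch size without copying.
    if batch_size != position_ids.shape[0]:
        position_ids = position_ids.expand(batch_size, -1)
    packed_sequence_mask = find_packed_sequence_indices(position_ids)

print(packed_sequence_mask)
# tensor([[0, 0, 1, 1, 1, 2],
#         [0, 0, 1, 1, 1, 2]])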
@@ -663,6 +721,7 @@ def create_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -684,6 +743,8 @@ def create_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -697,8 +758,8 @@ def create_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -711,6 +772,11 @@ def create_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
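To visualize what the `and_masks` combination produces for a packed batch, here is a standalone sketch for the [[0, 0, 1, 1, 1, 2]] example. The functions below are simplified stand-ins (the real `causal_mask_function`, `packed_sequence_mask_function` and `and_masks` operate on index tensors), but they illustrate the combined rule: causal within the batch row, further restricted to tokens of the same sequence.

import torch

packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def packed(batch_idx, head_idx, q_idx, kv_idx):
    return bool(packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx])

def combined(batch_idx, head_idx, q_idx, kv_idx):
    # and_masks(...): a position is attendable only if every component mask allows it
    return causal(batch_idx, head_idx, q_idx, kv_idx) and packed(batch_idx, head_idx, q_idx, kv_idx)

dense = torch.tensor([[combined(0, 0, q, k) for k in range(6)] for q in range(6)])
print(dense.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])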
@@ -744,6 +810,7 @@ def create_sliding_window_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -766,6 +833,8 @@ def create_sliding_window_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -779,8 +848,8 @@ def create_sliding_window_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -797,6 +866,11 @@ def create_sliding_window_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -831,6 +905,7 @@ def create_chunked_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -853,6 +928,8 @@ def create_chunked_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -866,8 +943,8 @@ def create_chunked_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -891,6 +968,11 @@ def create_chunked_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from chunked causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -932,6 +1014,7 @@ def create_masks_for_generate(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
     **kwargs,
@@ -953,6 +1036,8 @@ def create_masks_for_generate(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the other mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -969,6 +1054,7 @@ def create_masks_for_generate(
         "attention_mask": attention_mask,
         "cache_position": cache_position,
         "past_key_values": past_key_values,
+        "position_ids": position_ids,
         "or_mask_function": or_mask_function,
         "and_mask_function": and_mask_function,
     }