@@ -112,6 +112,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:
 
 
 def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask_function corresponding to a 2D padding mask.
+    """
+
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
         # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
@@ -121,6 +125,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     return inner_mask
 
 
+def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask_function corresponding to a 2D packed sequence mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+    return inner_mask
+
+
 def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
@@ -592,12 +607,40 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
 
 
+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
+    """
+    Find the indices of the sequence to which each new query token in the sequence belongs when using packed
+    tensor format (i.e. several sequences packed in the same batch dimension).
+
+    Args:
+        position_ids (`torch.Tensor`):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+    Returns:
+        A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
+        pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+    """
+    # What separates different sequences is when 2 consecutive position_ids are separated by more than 1. So
+    # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+    # gives exactly the sequence indices
+    # Note that we assume that a single sequence cannot span several batch dimensions, i.e. a single sequence
+    # cannot be part of the end of the first batch dim and the start of the 2nd one for example
+    first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
+    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
+    # but it causes issues with export
+    return packed_sequence_mask
+
+
 def _preprocess_mask_arguments(
     config: PretrainedConfig,
     input_embeds: torch.Tensor,
     attention_mask: Optional[Union[torch.Tensor, BlockMask]],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     layer_idx: Optional[int],
 ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
     """
@@ -617,6 +660,8 @@ def _preprocess_mask_arguments(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         layer_idx (`int`, optional):
             If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
             length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -626,22 +671,25 @@ def _preprocess_mask_arguments(
             Whether we should early exit mask creation, and return the mask as-is.
         attention_mask (`torch.Tensor` or `BlockMask` or `None`):
             The attention mask to either return immediately, or to use in downstream mask creation.
+        packed_sequence_mask (`torch.Tensor`, optional):
+            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
+            the tokens belong to the same sequence.
         kv_length (`int`):
             The size that the key and value states will have during the attention computation.
         kv_offset (`int`):
             An offset to indicate at which first position the key and values states will refer to.
     """
     # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
     if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
-        return True, attention_mask, None, None
+        return True, attention_mask, None, None, None
 
     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
     # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
     # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
     # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
     # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
     if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
-        return True, None, None, None
+        return True, None, None, None, None
 
     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
@@ -654,7 +702,17 @@ def _preprocess_mask_arguments(
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0
 
-    return False, attention_mask, kv_length, kv_offset
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
+    packed_sequence_mask = None
+    if position_ids is not None and attention_mask is None and past_key_values is None:
+        batch_size = input_embeds.shape[0]
+        # The position ids are sometimes just unsqueezed, without being expanded
+        if batch_size != position_ids.shape[0]:
+            position_ids = position_ids.expand(batch_size, -1)
+        packed_sequence_mask = find_packed_sequence_indices(position_ids)
+
+    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
 
 
 def create_causal_mask(
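One detail worth noting in the detection branch above: `position_ids` is sometimes passed as a single unsqueezed row of shape `(1, seq_len)`, so it is broadcast to the batch size before the sequence indices are computed. A minimal sketch of just that step (standalone, with made-up shapes):

```python
import torch

# position_ids provided as a single unsqueezed row rather than one row per batch element
position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])
batch_size = 2  # would come from input_embeds.shape[0]

if batch_size != position_ids.shape[0]:
    # expand returns a broadcasted view over the batch dimension, without copying the data
    position_ids = position_ids.expand(batch_size, -1)

print(position_ids.shape)  # torch.Size([2, 6])
```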
@@ -663,6 +721,7 @@ def create_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -684,6 +743,8 @@ def create_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -697,8 +758,8 @@ def create_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -711,6 +772,11 @@ def create_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
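For intuition, the `and_masks` composition above produces a block-diagonal causal pattern: causal within each packed sequence, and no attention across sequences. A rough standalone sketch (using a hand-rolled AND helper instead of the module's `and_masks`, and the toy indices from the earlier examples):

```python
import torch

packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def same_sequence(batch_idx, head_idx, q_idx, kv_idx):
    return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]

def and_mask_fns(*fns):
    # Minimal stand-in for the `and_masks` helper used in the diff
    def combined(batch_idx, head_idx, q_idx, kv_idx):
        return all(fn(batch_idx, head_idx, q_idx, kv_idx) for fn in fns)
    return combined

mask_fn = and_mask_fns(causal, same_sequence)
seq_len = packed_sequence_mask.shape[1]
mask = torch.tensor([[mask_fn(0, 0, q, kv) for kv in range(seq_len)] for q in range(seq_len)])
print(mask.int())
# 1 = attend, 0 = masked; each packed sequence forms its own causal block:
# [[1, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0],
#  [0, 0, 1, 0, 0, 0],
#  [0, 0, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 0],
#  [0, 0, 0, 0, 0, 1]]
```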
@@ -744,6 +810,7 @@ def create_sliding_window_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -766,6 +833,8 @@ def create_sliding_window_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -779,8 +848,8 @@ def create_sliding_window_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -797,6 +866,11 @@ def create_sliding_window_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -831,6 +905,7 @@ def create_chunked_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -853,6 +928,8 @@ def create_chunked_causal_mask(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -866,8 +943,8 @@ def create_chunked_causal_mask(
     else:
         layer_idx = 0
 
-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -891,6 +968,11 @@ def create_chunked_causal_mask(
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
 
+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from chunked causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -932,6 +1014,7 @@ def create_masks_for_generate(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
     **kwargs,
@@ -953,6 +1036,8 @@ def create_masks_for_generate(
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the other mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -969,6 +1054,7 @@ def create_masks_for_generate(
9691054 "attention_mask" : attention_mask ,
9701055 "cache_position" : cache_position ,
9711056 "past_key_values" : past_key_values ,
1057+ "position_ids" : position_ids ,
9721058 "or_mask_function" : or_mask_function ,
9731059 "and_mask_function" : and_mask_function ,
9741060 }