     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
+from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModel, AutoModelForCausalLM
 from .configuration_new_task_model import NewTaskModelConfig
 
@@ -254,8 +255,7 @@ def _update_causal_mask(
         token_type_ids,
         past_key_values,
         cache_position,
-        input_ids=None,
-        inputs_embeds=None,
+        input_tensor,
         is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
@@ -265,8 +265,7 @@ def _update_causal_mask(
 
         using_static_cache = isinstance(past_key_values, StaticCache)
         min_dtype = torch.finfo(self.dtype).min
-        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
-        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
         elif isinstance(past_key_values, HybridCache):
@@ -297,16 +296,20 @@ def _update_causal_mask(
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
+
+            # First unmask prefix tokens during training
+            if is_training:
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+                )
+
+            # Then apply padding mask (will mask pad tokens)
             padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
             )
-            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
-            if is_training:
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
-                )
+
         return causal_mask
 
     def get_image_features(self, pixel_values: torch.FloatTensor):
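Note on the reordering in this hunk: the prefix is now unmasked before the padding mask is applied, so padded positions can no longer be re-opened by the `token_type_ids == 0` unmasking step. The snippet below is a minimal, self-contained sketch of that interaction; the `build_mask` helper, shapes, and values are made up for illustration and are not the model's actual code.

```python
import torch

def build_mask(attention_mask, token_type_ids, is_training, dtype=torch.float32):
    bsz, seq_len = attention_mask.shape
    min_dtype = torch.finfo(dtype).min
    # standard additive causal mask: 0 on/below the diagonal, min_dtype above it
    causal_mask = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    causal_mask = causal_mask[None, None, :, :].expand(bsz, 1, -1, -1).clone()

    # 1) during training, unmask the prefix (token_type_id == 0) so it gets full attention
    if is_training:
        causal_mask = causal_mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)

    # 2) then apply the padding mask, so pad tokens stay masked even if step 1 touched them
    padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
    return causal_mask.masked_fill(padding_mask, min_dtype)

attention_mask = torch.tensor([[0, 1, 1, 1]])  # position 0 is padding
token_type_ids = torch.tensor([[0, 0, 0, 1]])  # positions 0-2 are the prefix, 3 is the suffix
mask = build_mask(attention_mask, token_type_ids, is_training=True)
assert (mask[0, 0, :, 0] == torch.finfo(torch.float32).min).all()  # the pad column stays masked
```

With the old order, the padding column would have been re-unmasked by the prefix step; with the new order it ends the routine at `min_dtype`, as the assertion checks.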
@@ -325,6 +328,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
         image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
     @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
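The `@deprecate_kwarg` decorator added above keeps `num_logits_to_keep` working during the deprecation window while routing it to `logits_to_keep`. As a rough illustration of that renaming pattern only (not the implementation in `transformers.utils.deprecation`; the `rename_kwarg` helper below is hypothetical):

```python
import functools
import warnings

def rename_kwarg(old_name, new_name, removed_in):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                # warn and remap the old keyword so existing callers keep working
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in {removed_in}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)
        return wrapper
    return decorator

@rename_kwarg("num_logits_to_keep", "logits_to_keep", removed_in="v4.50")
def forward(logits_to_keep=0):
    return logits_to_keep

assert forward(num_logits_to_keep=1) == 1  # old name still routed to the new parameter
```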
@@ -351,10 +355,12 @@ def forward(
             config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
         Returns:
 
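As a hedged sketch of how the two accepted types described in the docstring could be consumed when projecting hidden states to logits (the `project` helper and all shapes here are illustrative, not the model's code):

```python
import torch

hidden_states = torch.randn(2, 6, 16)          # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(16, 32, bias=False)  # (hidden -> vocab)

def project(hidden_states, logits_to_keep):
    # int: keep the last `logits_to_keep` positions (0 means keep everything);
    # 1D tensor: keep exactly the given sequence indices
    if isinstance(logits_to_keep, int):
        slice_indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None)
    else:
        slice_indices = logits_to_keep
    return lm_head(hidden_states[:, slice_indices, :])

assert project(hidden_states, 1).shape == (2, 1, 32)                     # last token only
assert project(hidden_states, 0).shape == (2, 6, 32)                     # special case: all tokens
assert project(hidden_states, torch.tensor([0, 5])).shape == (2, 2, 32)  # explicit indices
```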
@@ -418,7 +424,7 @@ def prepare_inputs_for_generation(
         attention_mask=None,
         token_type_ids=None,
         use_cache=True,
-        num_logits_to_keep=None,
+        logits_to_keep=None,
         labels=None,
         **kwargs,
     ):
@@ -431,7 +437,7 @@ def prepare_inputs_for_generation(
             position_ids=position_ids,
             cache_position=cache_position,
             use_cache=use_cache,
-            num_logits_to_keep=num_logits_to_keep,
+            logits_to_keep=logits_to_keep,
             token_type_ids=token_type_ids,
             **kwargs,
         )
@@ -445,10 +451,12 @@ def prepare_inputs_for_generation(
             model_inputs["pixel_values"] = pixel_values
         is_training = token_type_ids is not None and labels is not None
         if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
             causal_mask = self._update_causal_mask(
-                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
             )
             model_inputs["attention_mask"] = causal_mask
+
         return model_inputs
 
     def resize_token_embeddings(
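A small sketch of why passing a single `input_tensor` is enough for `_update_causal_mask`: both `input_ids` with shape `(batch, seq_len)` and `inputs_embeds` with shape `(batch, seq_len, hidden)` share the same two leading dimensions, so `shape[:2]` yields `(batch, seq_len)` either way. The shapes below are dummy values for illustration.

```python
import torch

input_ids = torch.randint(0, 100, (2, 7))   # (batch, seq_len)
inputs_embeds = torch.randn(2, 7, 16)       # (batch, seq_len, hidden)

for input_tensor in (input_ids, inputs_embeds):
    inputs_lead_dim, sequence_length = input_tensor.shape[:2]
    assert (inputs_lead_dim, sequence_length) == (2, 7)
```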